xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision 8f6a67632a70b757e59067646226bcaacd9c5bd7)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 //===----------------------------------------------------------------------===//
19 
20 #include "ARM.h"
21 #include "ARMBaseInstrInfo.h"
22 #include "ARMBaseRegisterInfo.h"
23 #include "ARMBasicBlockInfo.h"
24 #include "ARMSubtarget.h"
25 #include "llvm/ADT/SetOperations.h"
26 #include "llvm/ADT/SmallSet.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineLoopInfo.h"
29 #include "llvm/CodeGen/MachineLoopUtils.h"
30 #include "llvm/CodeGen/MachineRegisterInfo.h"
31 #include "llvm/CodeGen/Passes.h"
32 #include "llvm/CodeGen/ReachingDefAnalysis.h"
33 #include "llvm/MC/MCInstrDesc.h"
34 
35 using namespace llvm;
36 
// Debug type for the LLVM_DEBUG output in this file; it is also the pass
// argument given to INITIALIZE_PASS.
#define DEBUG_TYPE "arm-low-overhead-loops"
// Human-readable pass name, returned by getPassName() and registered with
// INITIALIZE_PASS.
#define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
39 
40 namespace {
41 
42   struct PredicatedMI {
43     MachineInstr *MI = nullptr;
44     SetVector<MachineInstr*> Predicates;
45 
46   public:
47     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
48     MI(I) {
49       Predicates.insert(Preds.begin(), Preds.end());
50     }
51   };
52 
53   // Represent a VPT block, a list of instructions that begins with a VPST and
54   // has a maximum of four proceeding instructions. All instructions within the
55   // block are predicated upon the vpr and we allow instructions to define the
56   // vpr within in the block too.
57   class VPTBlock {
58     std::unique_ptr<PredicatedMI> VPST;
59     PredicatedMI *Divergent = nullptr;
60     SmallVector<PredicatedMI, 4> Insts;
61 
62   public:
63     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
64       VPST = std::make_unique<PredicatedMI>(MI, Preds);
65     }
66 
67     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
68       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
69       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
70         Divergent = &Insts.back();
71         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
72       }
73       Insts.emplace_back(MI, Preds);
74       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
75     }
76 
77     // Have we found an instruction within the block which defines the vpr? If
78     // so, not all the instructions in the block will have the same predicate.
79     bool HasNonUniformPredicate() const {
80       return Divergent != nullptr;
81     }
82 
83     // Is the given instruction part of the predicate set controlling the entry
84     // to the block.
85     bool IsPredicatedOn(MachineInstr *MI) const {
86       return VPST->Predicates.count(MI);
87     }
88 
89     // Is the given instruction the only predicate which controls the entry to
90     // the block.
91     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
92       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
93     }
94 
95     unsigned size() const { return Insts.size(); }
96     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
97     MachineInstr *getVPST() const { return VPST->MI; }
98     PredicatedMI *getDivergent() const { return Divergent; }
99   };
100 
101   struct LowOverheadLoop {
102 
103     MachineLoop *ML = nullptr;
104     MachineFunction *MF = nullptr;
105     MachineInstr *InsertPt = nullptr;
106     MachineInstr *Start = nullptr;
107     MachineInstr *Dec = nullptr;
108     MachineInstr *End = nullptr;
109     MachineInstr *VCTP = nullptr;
110     VPTBlock *CurrentBlock = nullptr;
111     SetVector<MachineInstr*> CurrentPredicate;
112     SmallVector<VPTBlock, 4> VPTBlocks;
113     bool Revert = false;
114     bool CannotTailPredicate = false;
115 
116     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
117       MF = ML->getHeader()->getParent();
118     }
119 
120     bool RecordVPTBlocks(MachineInstr *MI);
121 
122     // If this is an MVE instruction, check that we know how to use tail
123     // predication with it.
124     void AnalyseMVEInst(MachineInstr *MI) {
125       if (CannotTailPredicate)
126         return;
127 
128       if (!RecordVPTBlocks(MI)) {
129         CannotTailPredicate = true;
130         return;
131       }
132 
133       const MCInstrDesc &MCID = MI->getDesc();
134       uint64_t Flags = MCID.TSFlags;
135       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
136         return;
137 
138       if ((Flags & ARMII::ValidForTailPredication) == 0) {
139         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
140         CannotTailPredicate = true;
141       }
142     }
143 
144     bool IsTailPredicationLegal() const {
145       // For now, let's keep things really simple and only support a single
146       // block for tail predication.
147       return !Revert && FoundAllComponents() && VCTP &&
148              !CannotTailPredicate && ML->getNumBlocks() == 1;
149     }
150 
151     bool ValidateTailPredicate(MachineInstr *StartInsertPt,
152                                ReachingDefAnalysis *RDA,
153                                MachineLoopInfo *MLI);
154 
155     // Is it safe to define LR with DLS/WLS?
156     // LR can be defined if it is the operand to start, because it's the same
157     // value, or if it's going to be equivalent to the operand to Start.
158     MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
159 
160     // Check the branch targets are within range and we satisfy our
161     // restrictions.
162     void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA,
163                        MachineLoopInfo *MLI);
164 
165     bool FoundAllComponents() const {
166       return Start && Dec && End;
167     }
168 
169     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
170 
171     // Return the loop iteration count, or the number of elements if we're tail
172     // predicating.
173     MachineOperand &getCount() {
174       return IsTailPredicationLegal() ?
175         VCTP->getOperand(1) : Start->getOperand(0);
176     }
177 
178     unsigned getStartOpcode() const {
179       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
180       if (!IsTailPredicationLegal())
181         return IsDo ? ARM::t2DLS : ARM::t2WLS;
182 
183       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
184     }
185 
186     void dump() const {
187       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
188       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
189       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
190       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
191       if (!FoundAllComponents())
192         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
193       else if (!(Start && Dec && End))
194         dbgs() << "ARM Loops: Failed to find all loop components.\n";
195     }
196   };
197 
198   class ARMLowOverheadLoops : public MachineFunctionPass {
199     MachineFunction           *MF = nullptr;
200     MachineLoopInfo           *MLI = nullptr;
201     ReachingDefAnalysis       *RDA = nullptr;
202     const ARMBaseInstrInfo    *TII = nullptr;
203     MachineRegisterInfo       *MRI = nullptr;
204     const TargetRegisterInfo  *TRI = nullptr;
205     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
206 
207   public:
208     static char ID;
209 
210     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
211 
212     void getAnalysisUsage(AnalysisUsage &AU) const override {
213       AU.setPreservesCFG();
214       AU.addRequired<MachineLoopInfo>();
215       AU.addRequired<ReachingDefAnalysis>();
216       MachineFunctionPass::getAnalysisUsage(AU);
217     }
218 
219     bool runOnMachineFunction(MachineFunction &MF) override;
220 
221     MachineFunctionProperties getRequiredProperties() const override {
222       return MachineFunctionProperties().set(
223           MachineFunctionProperties::Property::NoVRegs).set(
224           MachineFunctionProperties::Property::TracksLiveness);
225     }
226 
227     StringRef getPassName() const override {
228       return ARM_LOW_OVERHEAD_LOOPS_NAME;
229     }
230 
231   private:
232     bool ProcessLoop(MachineLoop *ML);
233 
234     bool RevertNonLoops();
235 
236     void RevertWhile(MachineInstr *MI) const;
237 
238     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
239 
240     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
241 
242     void RemoveLoopUpdate(LowOverheadLoop &LoLoop);
243 
244     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
245 
246     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
247 
248     void Expand(LowOverheadLoop &LoLoop);
249 
250   };
251 }
252 
// Storage for the pass identifier handed to MachineFunctionPass by the
// constructor.
char ARMLowOverheadLoops::ID = 0;

// Register the pass under DEBUG_TYPE with its human-readable name. The two
// trailing arguments are presumably INITIALIZE_PASS's "CFG only" and
// "is analysis" flags (both false) - see llvm/PassSupport.h.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
257 
258 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
259   // We can define LR because LR already contains the same value.
260   if (Start->getOperand(0).getReg() == ARM::LR)
261     return Start;
262 
263   unsigned CountReg = Start->getOperand(0).getReg();
264   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
265     return MI->getOpcode() == ARM::tMOVr &&
266            MI->getOperand(0).getReg() == ARM::LR &&
267            MI->getOperand(1).getReg() == CountReg &&
268            MI->getOperand(2).getImm() == ARMCC::AL;
269    };
270 
271   MachineBasicBlock *MBB = Start->getParent();
272 
273   // Find an insertion point:
274   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
275   //   to Count before Start, we can insert at that mov.
276   if (auto *LRDef = RDA->getReachingMIDef(Start, ARM::LR))
277     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
278       return LRDef;
279 
280   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
281   //   to Count after Start, we can insert at that mov.
282   if (auto *LRDef = RDA->getLocalLiveOutMIDef(MBB, ARM::LR))
283     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
284       return LRDef;
285 
286   // We've found no suitable LR def and Start doesn't use LR directly. Can we
287   // just define LR anyway?
288   if (!RDA->isRegUsedAfter(Start, ARM::LR))
289     return Start;
290 
291   return nullptr;
292 }
293 
294 // Can we safely move 'From' to just before 'To'? To satisfy this, 'From' must
295 // not define a register that is used by any instructions, after and including,
296 // 'To'. These instructions also must not redefine any of Froms operands.
297 template<typename Iterator>
298 static bool IsSafeToMove(MachineInstr *From, MachineInstr *To, ReachingDefAnalysis *RDA) {
299   SmallSet<int, 2> Defs;
300   // First check that From would compute the same value if moved.
301   for (auto &MO : From->operands()) {
302     if (!MO.isReg() || MO.isUndef() || !MO.getReg())
303       continue;
304     if (MO.isDef())
305       Defs.insert(MO.getReg());
306     else if (!RDA->hasSameReachingDef(From, To, MO.getReg()))
307       return false;
308   }
309 
310   // Now walk checking that the rest of the instructions will compute the same
311   // value.
312   for (auto I = ++Iterator(From), E = Iterator(To); I != E; ++I) {
313     for (auto &MO : I->operands())
314       if (MO.isReg() && MO.getReg() && MO.isUse() && Defs.count(MO.getReg()))
315         return false;
316   }
317   return true;
318 }
319 
320 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
321 		                            ReachingDefAnalysis *RDA,
322 		                            MachineLoopInfo *MLI) {
323   // All predication within the loop should be based on vctp. If the block
324   // isn't predicated on entry, check whether the vctp is within the block
325   // and that all other instructions are then predicated on it.
326   for (auto &Block : VPTBlocks) {
327     if (Block.IsPredicatedOn(VCTP))
328       continue;
329     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI))
330       return false;
331     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
332     for (auto &PredMI : Insts) {
333       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
334         continue;
335       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
336                         << " - which is predicated on:\n";
337                         for (auto *MI : PredMI.Predicates)
338                           dbgs() << "   - " << *MI;
339                  );
340       return false;
341     }
342   }
343 
344   // For tail predication, we need to provide the number of elements, instead
345   // of the iteration count, to the loop start instruction. The number of
346   // elements is provided to the vctp instruction, so we need to check that
347   // we can use this register at InsertPt.
348   Register NumElements = VCTP->getOperand(1).getReg();
349 
350   // If the register is defined within loop, then we can't perform TP.
351   // TODO: Check whether this is just a mov of a register that would be
352   // available.
353   if (RDA->getReachingDef(VCTP, NumElements) >= 0) {
354     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
355     return false;
356   }
357 
358   // The element count register maybe defined after InsertPt, in which case we
359   // need to try to move either InsertPt or the def so that the [w|d]lstp can
360   // use the value.
361   MachineBasicBlock *InsertBB = InsertPt->getParent();
362   if (!RDA->isReachingDefLiveOut(InsertPt, NumElements)) {
363     if (auto *ElemDef = RDA->getLocalLiveOutMIDef(InsertBB, NumElements)) {
364       if (IsSafeToMove<MachineBasicBlock::reverse_iterator>(ElemDef, InsertPt, RDA)) {
365         ElemDef->removeFromParent();
366         InsertBB->insert(MachineBasicBlock::iterator(InsertPt), ElemDef);
367         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
368                    << *ElemDef);
369       } else if (IsSafeToMove<MachineBasicBlock::iterator>(InsertPt, ElemDef, RDA)) {
370         InsertPt->removeFromParent();
371         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), InsertPt);
372         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
373       } else
374         return false;
375     }
376   }
377 
378   // Especially in the case of while loops, InsertBB may not be the
379   // preheader, so we need to check that the register isn't redefined
380   // before entering the loop.
381   auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB,
382                                       Register NumElements) {
383     // NumElements is redefined in this block.
384     if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0)
385       return true;
386 
387     // Don't continue searching up through multiple predecessors.
388     if (MBB->pred_size() > 1)
389       return true;
390 
391     return false;
392   };
393 
394   // First, find the block that looks like the preheader.
395   MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true);
396   if (!MBB)
397     return false;
398 
399   // Then search backwards for a def, until we get to InsertBB.
400   while (MBB != InsertBB) {
401     if (CannotProvideElements(MBB, NumElements))
402       return false;
403     MBB = *MBB->pred_begin();
404   }
405 
406   LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n");
407   return true;
408 }
409 
410 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
411                                     ReachingDefAnalysis *RDA,
412                                     MachineLoopInfo *MLI) {
413   if (Revert)
414     return;
415 
416   if (!End->getOperand(1).isMBB())
417     report_fatal_error("Expected LoopEnd to target basic block");
418 
419   // TODO Maybe there's cases where the target doesn't have to be the header,
420   // but for now be safe and revert.
421   if (End->getOperand(1).getMBB() != ML->getHeader()) {
422     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
423     Revert = true;
424     return;
425   }
426 
427   // The WLS and LE instructions have 12-bits for the label offset. WLS
428   // requires a positive offset, while LE uses negative.
429   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
430       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
431     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
432     Revert = true;
433     return;
434   }
435 
436   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
437       (BBUtils->getOffsetOf(Start) >
438        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
439        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
440     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
441     Revert = true;
442     return;
443   }
444 
445   InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
446   if (!InsertPt) {
447     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
448     Revert = true;
449     return;
450   } else
451     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
452 
453   if (!IsTailPredicationLegal()) {
454     LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n");
455     return;
456   }
457 
458   assert(ML->getBlocks().size() == 1 &&
459          "Shouldn't be processing a loop with more than one block");
460   CannotTailPredicate = !ValidateTailPredicate(InsertPt, RDA, MLI);
461   LLVM_DEBUG(if (CannotTailPredicate)
462              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
463 }
464 
465 bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) {
466   // Only support a single vctp.
467   if (isVCTP(MI) && VCTP)
468     return false;
469 
470   // Start a new vpt block when we discover a vpt.
471   if (MI->getOpcode() == ARM::MVE_VPST) {
472     VPTBlocks.emplace_back(MI, CurrentPredicate);
473     CurrentBlock = &VPTBlocks.back();
474     return true;
475   }
476 
477   if (isVCTP(MI))
478     VCTP = MI;
479 
480   unsigned VPROpNum = MI->getNumOperands() - 1;
481   bool IsUse = false;
482   if (MI->getOperand(VPROpNum).isReg() &&
483       MI->getOperand(VPROpNum).getReg() == ARM::VPR &&
484       MI->getOperand(VPROpNum).isUse()) {
485     // If this instruction is predicated by VPR, it will be its last
486     // operand.  Also check that it's only 'Then' predicated.
487     if (!MI->getOperand(VPROpNum-1).isImm() ||
488         MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) {
489       LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: "
490                  << *MI);
491       return false;
492     }
493     CurrentBlock->addInst(MI, CurrentPredicate);
494     IsUse = true;
495   }
496 
497   bool IsDef = false;
498   for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) {
499     const MachineOperand &MO = MI->getOperand(i);
500     if (!MO.isReg() || MO.getReg() != ARM::VPR)
501       continue;
502 
503     if (MO.isDef()) {
504       CurrentPredicate.insert(MI);
505       IsDef = true;
506     } else {
507       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
508       return false;
509     }
510   }
511 
512   // If we find a vpr def that is not already predicated on the vctp, we've
513   // got disjoint predicates that may not be equivalent when we do the
514   // conversion.
515   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
516     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
517     return false;
518   }
519 
520   return true;
521 }
522 
523 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
524   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
525   if (!ST.hasLOB())
526     return false;
527 
528   MF = &mf;
529   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
530 
531   MLI = &getAnalysis<MachineLoopInfo>();
532   RDA = &getAnalysis<ReachingDefAnalysis>();
533   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
534   MRI = &MF->getRegInfo();
535   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
536   TRI = ST.getRegisterInfo();
537   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
538   BBUtils->computeAllBlockSizes();
539   BBUtils->adjustBBOffsetsAfter(&MF->front());
540 
541   bool Changed = false;
542   for (auto ML : *MLI) {
543     if (!ML->getParentLoop())
544       Changed |= ProcessLoop(ML);
545   }
546   Changed |= RevertNonLoops();
547   return Changed;
548 }
549 
550 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
551 
552   bool Changed = false;
553 
554   // Process inner loops first.
555   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
556     Changed |= ProcessLoop(*I);
557 
558   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
559              if (auto *Preheader = ML->getLoopPreheader())
560                dbgs() << " - " << Preheader->getName() << "\n";
561              else if (auto *Preheader = MLI->findLoopPreheader(ML))
562                dbgs() << " - " << Preheader->getName() << "\n";
563              for (auto *MBB : ML->getBlocks())
564                dbgs() << " - " << MBB->getName() << "\n";
565             );
566 
567   // Search the given block for a loop start instruction. If one isn't found,
568   // and there's only one predecessor block, search that one too.
569   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
570     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
571     for (auto &MI : *MBB) {
572       if (isLoopStart(MI))
573         return &MI;
574     }
575     if (MBB->pred_size() == 1)
576       return SearchForStart(*MBB->pred_begin());
577     return nullptr;
578   };
579 
580   LowOverheadLoop LoLoop(ML);
581   // Search the preheader for the start intrinsic.
582   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
583   // with potentially multiple set.loop.iterations, so we need to enable this.
584   if (auto *Preheader = ML->getLoopPreheader())
585     LoLoop.Start = SearchForStart(Preheader);
586   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
587     LoLoop.Start = SearchForStart(Preheader);
588   else
589     return false;
590 
591   // Find the low-overhead loop components and decide whether or not to fall
592   // back to a normal loop. Also look for a vctp instructions and decide
593   // whether we can convert that predicate using tail predication.
594   for (auto *MBB : reverse(ML->getBlocks())) {
595     for (auto &MI : *MBB) {
596       if (MI.getOpcode() == ARM::t2LoopDec)
597         LoLoop.Dec = &MI;
598       else if (MI.getOpcode() == ARM::t2LoopEnd)
599         LoLoop.End = &MI;
600       else if (isLoopStart(MI))
601         LoLoop.Start = &MI;
602       else if (MI.getDesc().isCall()) {
603         // TODO: Though the call will require LE to execute again, does this
604         // mean we should revert? Always executing LE hopefully should be
605         // faster than performing a sub,cmp,br or even subs,br.
606         LoLoop.Revert = true;
607         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
608       } else {
609         // Record VPR defs and build up their corresponding vpt blocks.
610         // Check we know how to tail predicate any mve instructions.
611         LoLoop.AnalyseMVEInst(&MI);
612       }
613 
614       // We need to ensure that LR is not used or defined inbetween LoopDec and
615       // LoopEnd.
616       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
617         continue;
618 
619       // If we find that LR has been written or read between LoopDec and
620       // LoopEnd, expect that the decremented value is being used else where.
621       // Because this value isn't actually going to be produced until the
622       // latch, by LE, we would need to generate a real sub. The value is also
623       // likely to be copied/reloaded for use of LoopEnd - in which in case
624       // we'd need to perform an add because it gets subtracted again by LE!
625       // The other option is to then generate the other form of LE which doesn't
626       // perform the sub.
627       for (auto &MO : MI.operands()) {
628         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
629             MO.getReg() == ARM::LR) {
630           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
631           LoLoop.Revert = true;
632           break;
633         }
634       }
635     }
636   }
637 
638   LLVM_DEBUG(LoLoop.dump());
639   if (!LoLoop.FoundAllComponents())
640     return false;
641 
642   LoLoop.CheckLegality(BBUtils.get(), RDA, MLI);
643   Expand(LoLoop);
644   return true;
645 }
646 
647 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
648 // beq that branches to the exit branch.
649 // TODO: We could also try to generate a cbz if the value in LR is also in
650 // another low register.
651 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
652   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
653   MachineBasicBlock *MBB = MI->getParent();
654   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
655                                     TII->get(ARM::t2CMPri));
656   MIB.add(MI->getOperand(0));
657   MIB.addImm(0);
658   MIB.addImm(ARMCC::AL);
659   MIB.addReg(ARM::NoRegister);
660 
661   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
662   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
663     ARM::tBcc : ARM::t2Bcc;
664 
665   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
666   MIB.add(MI->getOperand(1));   // branch target
667   MIB.addImm(ARMCC::EQ);        // condition code
668   MIB.addReg(ARM::CPSR);
669   MI->eraseFromParent();
670 }
671 
672 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
673                                         bool SetFlags) const {
674   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
675   MachineBasicBlock *MBB = MI->getParent();
676 
677   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
678   if (SetFlags &&
679       (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
680        !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
681       SetFlags = false;
682 
683   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
684                                     TII->get(ARM::t2SUBri));
685   MIB.addDef(ARM::LR);
686   MIB.add(MI->getOperand(1));
687   MIB.add(MI->getOperand(2));
688   MIB.addImm(ARMCC::AL);
689   MIB.addReg(0);
690 
691   if (SetFlags) {
692     MIB.addReg(ARM::CPSR);
693     MIB->getOperand(5).setIsDef(true);
694   } else
695     MIB.addReg(0);
696 
697   MI->eraseFromParent();
698   return SetFlags;
699 }
700 
701 // Generate a subs, or sub and cmp, and a branch instead of an LE.
702 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
703   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
704 
705   MachineBasicBlock *MBB = MI->getParent();
706   // Create cmp
707   if (!SkipCmp) {
708     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
709                                       TII->get(ARM::t2CMPri));
710     MIB.addReg(ARM::LR);
711     MIB.addImm(0);
712     MIB.addImm(ARMCC::AL);
713     MIB.addReg(ARM::NoRegister);
714   }
715 
716   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
717   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
718     ARM::tBcc : ARM::t2Bcc;
719 
720   // Create bne
721   MachineInstrBuilder MIB =
722     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
723   MIB.add(MI->getOperand(1));   // branch target
724   MIB.addImm(ARMCC::NE);        // condition code
725   MIB.addReg(ARM::CPSR);
726   MI->eraseFromParent();
727 }
728 
729 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
730   MachineInstr *InsertPt = LoLoop.InsertPt;
731   MachineInstr *Start = LoLoop.Start;
732   MachineBasicBlock *MBB = InsertPt->getParent();
733   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
734   unsigned Opc = LoLoop.getStartOpcode();
735   MachineOperand &Count = LoLoop.getCount();
736 
737   MachineInstrBuilder MIB =
738     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
739 
740   MIB.addDef(ARM::LR);
741   MIB.add(Count);
742   if (!IsDo)
743     MIB.add(Start->getOperand(1));
744 
745   // When using tail-predication, try to delete the dead code that was used to
746   // calculate the number of loop iterations.
747   if (LoLoop.IsTailPredicationLegal()) {
748     SmallVector<MachineInstr*, 4> Killed;
749     SmallVector<MachineInstr*, 4> Dead;
750     if (auto *Def = RDA->getReachingMIDef(Start,
751                                           Start->getOperand(0).getReg())) {
752       Killed.push_back(Def);
753 
754       while (!Killed.empty()) {
755         MachineInstr *Def = Killed.back();
756         Killed.pop_back();
757         Dead.push_back(Def);
758         for (auto &MO : Def->operands()) {
759           if (!MO.isReg() || !MO.isKill())
760             continue;
761 
762           MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg());
763           if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1)
764             Killed.push_back(Kill);
765         }
766       }
767       for (auto *MI : Dead)
768         MI->eraseFromParent();
769     }
770   }
771 
772   // If we're inserting at a mov lr, then remove it as it's redundant.
773   if (InsertPt != Start)
774     InsertPt->eraseFromParent();
775   Start->eraseFromParent();
776   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
777   return &*MIB;
778 }
779 
780 // Goal is to optimise and clean-up these loops:
781 //
782 //   vector.body:
783 //     renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg
784 //     renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4
785 //     ..
786 //     $lr = MVE_DLSTP_32 renamable $r3
787 //
788 // The SUB is the old update of the loop iteration count expression, which
789 // is no longer needed. This sub is removed when the element count, which is in
790 // r3 in this example, is defined by an instruction in the loop, and it has
791 // no uses.
792 //
793 void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) {
794   Register ElemCount = LoLoop.VCTP->getOperand(1).getReg();
795   MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back();
796 
797   LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n");
798 
799   if (LoLoop.ML->getNumBlocks() != 1) {
800     LLVM_DEBUG(dbgs() << "ARM Loops: single block loop expected\n");
801     return;
802   }
803 
804   LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing MO: ";
805              LoLoop.VCTP->getOperand(1).dump());
806 
807   // Find the definition we are interested in removing, if there is one.
808   MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount);
809   if (!Def)
810     return;
811 
812   // Bail if we define CPSR and it is not dead
813   if (!Def->registerDefIsDead(ARM::CPSR, TRI)) {
814     LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n");
815     return;
816   }
817 
818   // Bail if elemcount is used in exit blocks, i.e. if it is live-in.
819   if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) {
820     LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n");
821     return;
822   }
823 
824   // Bail if there are uses after this Def in the block.
825   SmallVector<MachineInstr*, 4> Uses;
826   RDA->getReachingLocalUses(Def, ElemCount, Uses);
827   if (Uses.size()) {
828     LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n");
829     return;
830   }
831 
832   Uses.clear();
833   RDA->getAllInstWithUseBefore(Def, ElemCount, Uses);
834 
835   // Remove Def if there are no uses, or if the only use is the VCTP
836   // instruction.
837   if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) {
838     LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: ";
839                Def->dump());
840     Def->eraseFromParent();
841   }
842 }
843 
844 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
845   auto RemovePredicate = [](MachineInstr *MI) {
846     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
847     unsigned OpNum = MI->getNumOperands() - 1;
848     assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then &&
849            "Expected Then predicate!");
850     MI->getOperand(OpNum-1).setImm(ARMVCC::None);
851     MI->getOperand(OpNum).setReg(0);
852   };
853 
854   // There are a few scenarios which we have to fix up:
855   // 1) A VPT block with is only predicated by the vctp and has no internal vpr
856   //    defs.
857   // 2) A VPT block which is only predicated by the vctp but has an internal
858   //    vpr def.
859   // 3) A VPT block which is predicated upon the vctp as well as another vpr
860   //    def.
861   // 4) A VPT block which is not predicated upon a vctp, but contains it and
862   //    all instructions within the block are predicated upon in.
863 
864   for (auto &Block : LoLoop.getVPTBlocks()) {
865     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
866     if (Block.HasNonUniformPredicate()) {
867       PredicatedMI *Divergent = Block.getDivergent();
868       if (isVCTP(Divergent->MI)) {
869         // The vctp will be removed, so the size of the vpt block needs to be
870         // modified.
871         uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
872         Block.getVPST()->getOperand(0).setImm(Size);
873         LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
874       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
875         // The VPT block has a non-uniform predicate but it's entry is guarded
876         // only by a vctp, which means we:
877         // - Need to remove the original vpst.
878         // - Then need to unpredicate any following instructions, until
879         //   we come across the divergent vpr def.
880         // - Insert a new vpst to predicate the instruction(s) that following
881         //   the divergent vpr def.
882         // TODO: We could be producing more VPT blocks than necessary and could
883         // fold the newly created one into a proceeding one.
884         for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
885              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
886           RemovePredicate(&*I);
887 
888         unsigned Size = 0;
889         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
890         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
891         MachineInstr *InsertAt = nullptr;
892         while (I != E) {
893           InsertAt = &*I;
894           ++Size;
895           ++I;
896         }
897         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
898                                           InsertAt->getDebugLoc(),
899                                           TII->get(ARM::MVE_VPST));
900         MIB.addImm(getARMVPTBlockMask(Size));
901         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
902         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
903         Block.getVPST()->eraseFromParent();
904       }
905     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
906       // A vpt block which is only predicated upon vctp and has no internal vpr
907       // defs:
908       // - Remove vpst.
909       // - Unpredicate the remaining instructions.
910       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
911       Block.getVPST()->eraseFromParent();
912       for (auto &PredMI : Insts)
913         RemovePredicate(PredMI.MI);
914     }
915   }
916 
917   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
918   LoLoop.VCTP->eraseFromParent();
919 }
920 
921 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
922 
923   // Combine the LoopDec and LoopEnd instructions into LE(TP).
924   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
925     MachineInstr *End = LoLoop.End;
926     MachineBasicBlock *MBB = End->getParent();
927     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
928       ARM::MVE_LETP : ARM::t2LEUpdate;
929     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
930                                       TII->get(Opc));
931     MIB.addDef(ARM::LR);
932     MIB.add(End->getOperand(0));
933     MIB.add(End->getOperand(1));
934     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
935 
936     LoLoop.End->eraseFromParent();
937     LoLoop.Dec->eraseFromParent();
938     return &*MIB;
939   };
940 
941   // TODO: We should be able to automatically remove these branches before we
942   // get here - probably by teaching analyzeBranch about the pseudo
943   // instructions.
944   // If there is an unconditional branch, after I, that just branches to the
945   // next block, remove it.
946   auto RemoveDeadBranch = [](MachineInstr *I) {
947     MachineBasicBlock *BB = I->getParent();
948     MachineInstr *Terminator = &BB->instr_back();
949     if (Terminator->isUnconditionalBranch() && I != Terminator) {
950       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
951       if (BB->isLayoutSuccessor(Succ)) {
952         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
953         Terminator->eraseFromParent();
954       }
955     }
956   };
957 
958   if (LoLoop.Revert) {
959     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
960       RevertWhile(LoLoop.Start);
961     else
962       LoLoop.Start->eraseFromParent();
963     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
964     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
965   } else {
966     LoLoop.Start = ExpandLoopStart(LoLoop);
967     RemoveDeadBranch(LoLoop.Start);
968     LoLoop.End = ExpandLoopEnd(LoLoop);
969     RemoveDeadBranch(LoLoop.End);
970     if (LoLoop.IsTailPredicationLegal()) {
971       RemoveLoopUpdate(LoLoop);
972       ConvertVPTBlocks(LoLoop);
973     }
974   }
975 }
976 
977 bool ARMLowOverheadLoops::RevertNonLoops() {
978   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
979   bool Changed = false;
980 
981   for (auto &MBB : *MF) {
982     SmallVector<MachineInstr*, 4> Starts;
983     SmallVector<MachineInstr*, 4> Decs;
984     SmallVector<MachineInstr*, 4> Ends;
985 
986     for (auto &I : MBB) {
987       if (isLoopStart(I))
988         Starts.push_back(&I);
989       else if (I.getOpcode() == ARM::t2LoopDec)
990         Decs.push_back(&I);
991       else if (I.getOpcode() == ARM::t2LoopEnd)
992         Ends.push_back(&I);
993     }
994 
995     if (Starts.empty() && Decs.empty() && Ends.empty())
996       continue;
997 
998     Changed = true;
999 
1000     for (auto *Start : Starts) {
1001       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1002         RevertWhile(Start);
1003       else
1004         Start->eraseFromParent();
1005     }
1006     for (auto *Dec : Decs)
1007       RevertLoopDec(Dec);
1008 
1009     for (auto *End : Ends)
1010       RevertLoopEnd(End);
1011   }
1012   return Changed;
1013 }
1014 
1015 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1016   return new ARMLowOverheadLoops();
1017 }
1018