xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision 0efc9e5a8cc12b9cb30adf2a3dbb14ffbc60e338)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
/// - t2LoopEnd - the loop latch terminator.
///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
21 ///
22 //===----------------------------------------------------------------------===//
23 
24 #include "ARM.h"
25 #include "ARMBaseInstrInfo.h"
26 #include "ARMBaseRegisterInfo.h"
27 #include "ARMBasicBlockInfo.h"
28 #include "ARMSubtarget.h"
29 #include "llvm/ADT/SetOperations.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/CodeGen/MachineFunctionPass.h"
32 #include "llvm/CodeGen/MachineLoopInfo.h"
33 #include "llvm/CodeGen/MachineLoopUtils.h"
34 #include "llvm/CodeGen/MachineRegisterInfo.h"
35 #include "llvm/CodeGen/Passes.h"
36 #include "llvm/CodeGen/ReachingDefAnalysis.h"
37 #include "llvm/MC/MCInstrDesc.h"
38 
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "arm-low-overhead-loops"
42 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
43 
44 namespace {
45 
46   struct PredicatedMI {
47     MachineInstr *MI = nullptr;
48     SetVector<MachineInstr*> Predicates;
49 
50   public:
51     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
52     MI(I) {
53       Predicates.insert(Preds.begin(), Preds.end());
54     }
55   };
56 
57   // Represent a VPT block, a list of instructions that begins with a VPST and
58   // has a maximum of four proceeding instructions. All instructions within the
59   // block are predicated upon the vpr and we allow instructions to define the
60   // vpr within in the block too.
61   class VPTBlock {
62     std::unique_ptr<PredicatedMI> VPST;
63     PredicatedMI *Divergent = nullptr;
64     SmallVector<PredicatedMI, 4> Insts;
65 
66   public:
67     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
68       VPST = std::make_unique<PredicatedMI>(MI, Preds);
69     }
70 
71     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
72       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
73       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
74         Divergent = &Insts.back();
75         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
76       }
77       Insts.emplace_back(MI, Preds);
78       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
79     }
80 
81     // Have we found an instruction within the block which defines the vpr? If
82     // so, not all the instructions in the block will have the same predicate.
83     bool HasNonUniformPredicate() const {
84       return Divergent != nullptr;
85     }
86 
87     // Is the given instruction part of the predicate set controlling the entry
88     // to the block.
89     bool IsPredicatedOn(MachineInstr *MI) const {
90       return VPST->Predicates.count(MI);
91     }
92 
93     // Is the given instruction the only predicate which controls the entry to
94     // the block.
95     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
96       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
97     }
98 
99     unsigned size() const { return Insts.size(); }
100     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
101     MachineInstr *getVPST() const { return VPST->MI; }
102     PredicatedMI *getDivergent() const { return Divergent; }
103   };
104 
105   struct LowOverheadLoop {
106 
107     MachineLoop *ML = nullptr;
108     MachineFunction *MF = nullptr;
109     MachineInstr *InsertPt = nullptr;
110     MachineInstr *Start = nullptr;
111     MachineInstr *Dec = nullptr;
112     MachineInstr *End = nullptr;
113     MachineInstr *VCTP = nullptr;
114     VPTBlock *CurrentBlock = nullptr;
115     SetVector<MachineInstr*> CurrentPredicate;
116     SmallVector<VPTBlock, 4> VPTBlocks;
117     bool Revert = false;
118     bool CannotTailPredicate = false;
119 
120     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
121       MF = ML->getHeader()->getParent();
122     }
123 
124     bool RecordVPTBlocks(MachineInstr *MI);
125 
126     // If this is an MVE instruction, check that we know how to use tail
127     // predication with it.
128     void AnalyseMVEInst(MachineInstr *MI) {
129       if (CannotTailPredicate)
130         return;
131 
132       if (!RecordVPTBlocks(MI)) {
133         CannotTailPredicate = true;
134         return;
135       }
136 
137       const MCInstrDesc &MCID = MI->getDesc();
138       uint64_t Flags = MCID.TSFlags;
139       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
140         return;
141 
142       if ((Flags & ARMII::ValidForTailPredication) == 0) {
143         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
144         CannotTailPredicate = true;
145       }
146     }
147 
148     bool IsTailPredicationLegal() const {
149       // For now, let's keep things really simple and only support a single
150       // block for tail predication.
151       return !Revert && FoundAllComponents() && VCTP &&
152              !CannotTailPredicate && ML->getNumBlocks() == 1;
153     }
154 
155     bool ValidateTailPredicate(MachineInstr *StartInsertPt,
156                                ReachingDefAnalysis *RDA,
157                                MachineLoopInfo *MLI);
158 
159     // Is it safe to define LR with DLS/WLS?
160     // LR can be defined if it is the operand to start, because it's the same
161     // value, or if it's going to be equivalent to the operand to Start.
162     MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
163 
164     // Check the branch targets are within range and we satisfy our
165     // restrictions.
166     void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA,
167                        MachineLoopInfo *MLI);
168 
169     bool FoundAllComponents() const {
170       return Start && Dec && End;
171     }
172 
173     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
174 
175     // Return the loop iteration count, or the number of elements if we're tail
176     // predicating.
177     MachineOperand &getCount() {
178       return IsTailPredicationLegal() ?
179         VCTP->getOperand(1) : Start->getOperand(0);
180     }
181 
182     unsigned getStartOpcode() const {
183       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
184       if (!IsTailPredicationLegal())
185         return IsDo ? ARM::t2DLS : ARM::t2WLS;
186 
187       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
188     }
189 
190     void dump() const {
191       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
192       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
193       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
194       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
195       if (!FoundAllComponents())
196         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
197       else if (!(Start && Dec && End))
198         dbgs() << "ARM Loops: Failed to find all loop components.\n";
199     }
200   };
201 
202   class ARMLowOverheadLoops : public MachineFunctionPass {
203     MachineFunction           *MF = nullptr;
204     MachineLoopInfo           *MLI = nullptr;
205     ReachingDefAnalysis       *RDA = nullptr;
206     const ARMBaseInstrInfo    *TII = nullptr;
207     MachineRegisterInfo       *MRI = nullptr;
208     const TargetRegisterInfo  *TRI = nullptr;
209     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
210 
211   public:
212     static char ID;
213 
214     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
215 
216     void getAnalysisUsage(AnalysisUsage &AU) const override {
217       AU.setPreservesCFG();
218       AU.addRequired<MachineLoopInfo>();
219       AU.addRequired<ReachingDefAnalysis>();
220       MachineFunctionPass::getAnalysisUsage(AU);
221     }
222 
223     bool runOnMachineFunction(MachineFunction &MF) override;
224 
225     MachineFunctionProperties getRequiredProperties() const override {
226       return MachineFunctionProperties().set(
227           MachineFunctionProperties::Property::NoVRegs).set(
228           MachineFunctionProperties::Property::TracksLiveness);
229     }
230 
231     StringRef getPassName() const override {
232       return ARM_LOW_OVERHEAD_LOOPS_NAME;
233     }
234 
235   private:
236     bool ProcessLoop(MachineLoop *ML);
237 
238     bool RevertNonLoops();
239 
240     void RevertWhile(MachineInstr *MI) const;
241 
242     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
243 
244     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
245 
246     void RemoveLoopUpdate(LowOverheadLoop &LoLoop);
247 
248     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
249 
250     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
251 
252     void Expand(LowOverheadLoop &LoLoop);
253 
254   };
255 }
256 
// Pass identifier handed to MachineFunctionPass by the constructor.
char ARMLowOverheadLoops::ID = 0;

// Register the pass under the DEBUG_TYPE name ("arm-low-overhead-loops") with
// the human-readable ARM_LOW_OVERHEAD_LOOPS_NAME. The trailing flags are
// INITIALIZE_PASS's CFG-only / is-analysis arguments, both false here.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
261 
262 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
263   // We can define LR because LR already contains the same value.
264   if (Start->getOperand(0).getReg() == ARM::LR)
265     return Start;
266 
267   unsigned CountReg = Start->getOperand(0).getReg();
268   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
269     return MI->getOpcode() == ARM::tMOVr &&
270            MI->getOperand(0).getReg() == ARM::LR &&
271            MI->getOperand(1).getReg() == CountReg &&
272            MI->getOperand(2).getImm() == ARMCC::AL;
273    };
274 
275   MachineBasicBlock *MBB = Start->getParent();
276 
277   // Find an insertion point:
278   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
279   //   to Count before Start, we can insert at that mov.
280   if (auto *LRDef = RDA->getReachingMIDef(Start, ARM::LR))
281     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
282       return LRDef;
283 
284   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
285   //   to Count after Start, we can insert at that mov.
286   if (auto *LRDef = RDA->getLocalLiveOutMIDef(MBB, ARM::LR))
287     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
288       return LRDef;
289 
290   // We've found no suitable LR def and Start doesn't use LR directly. Can we
291   // just define LR anyway?
292   if (!RDA->isRegUsedAfter(Start, ARM::LR))
293     return Start;
294 
295   return nullptr;
296 }
297 
298 // Can we safely move 'From' to just before 'To'? To satisfy this, 'From' must
299 // not define a register that is used by any instructions, after and including,
300 // 'To'. These instructions also must not redefine any of Froms operands.
301 template<typename Iterator>
302 static bool IsSafeToMove(MachineInstr *From, MachineInstr *To, ReachingDefAnalysis *RDA) {
303   SmallSet<int, 2> Defs;
304   // First check that From would compute the same value if moved.
305   for (auto &MO : From->operands()) {
306     if (!MO.isReg() || MO.isUndef() || !MO.getReg())
307       continue;
308     if (MO.isDef())
309       Defs.insert(MO.getReg());
310     else if (!RDA->hasSameReachingDef(From, To, MO.getReg()))
311       return false;
312   }
313 
314   // Now walk checking that the rest of the instructions will compute the same
315   // value.
316   for (auto I = ++Iterator(From), E = Iterator(To); I != E; ++I) {
317     for (auto &MO : I->operands())
318       if (MO.isReg() && MO.getReg() && MO.isUse() && Defs.count(MO.getReg()))
319         return false;
320   }
321   return true;
322 }
323 
324 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
325 		                            ReachingDefAnalysis *RDA,
326 		                            MachineLoopInfo *MLI) {
327   // All predication within the loop should be based on vctp. If the block
328   // isn't predicated on entry, check whether the vctp is within the block
329   // and that all other instructions are then predicated on it.
330   for (auto &Block : VPTBlocks) {
331     if (Block.IsPredicatedOn(VCTP))
332       continue;
333     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI))
334       return false;
335     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
336     for (auto &PredMI : Insts) {
337       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
338         continue;
339       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
340                         << " - which is predicated on:\n";
341                         for (auto *MI : PredMI.Predicates)
342                           dbgs() << "   - " << *MI;
343                  );
344       return false;
345     }
346   }
347 
348   // For tail predication, we need to provide the number of elements, instead
349   // of the iteration count, to the loop start instruction. The number of
350   // elements is provided to the vctp instruction, so we need to check that
351   // we can use this register at InsertPt.
352   Register NumElements = VCTP->getOperand(1).getReg();
353 
354   // If the register is defined within loop, then we can't perform TP.
355   // TODO: Check whether this is just a mov of a register that would be
356   // available.
357   if (RDA->getReachingDef(VCTP, NumElements) >= 0) {
358     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
359     return false;
360   }
361 
362   // The element count register maybe defined after InsertPt, in which case we
363   // need to try to move either InsertPt or the def so that the [w|d]lstp can
364   // use the value.
365   MachineBasicBlock *InsertBB = InsertPt->getParent();
366   if (!RDA->isReachingDefLiveOut(InsertPt, NumElements)) {
367     if (auto *ElemDef = RDA->getLocalLiveOutMIDef(InsertBB, NumElements)) {
368       if (IsSafeToMove<MachineBasicBlock::reverse_iterator>(ElemDef, InsertPt, RDA)) {
369         ElemDef->removeFromParent();
370         InsertBB->insert(MachineBasicBlock::iterator(InsertPt), ElemDef);
371         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
372                    << *ElemDef);
373       } else if (IsSafeToMove<MachineBasicBlock::iterator>(InsertPt, ElemDef, RDA)) {
374         InsertPt->removeFromParent();
375         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), InsertPt);
376         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
377       } else
378         return false;
379     }
380   }
381 
382   // Especially in the case of while loops, InsertBB may not be the
383   // preheader, so we need to check that the register isn't redefined
384   // before entering the loop.
385   auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB,
386                                       Register NumElements) {
387     // NumElements is redefined in this block.
388     if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0)
389       return true;
390 
391     // Don't continue searching up through multiple predecessors.
392     if (MBB->pred_size() > 1)
393       return true;
394 
395     return false;
396   };
397 
398   // First, find the block that looks like the preheader.
399   MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true);
400   if (!MBB)
401     return false;
402 
403   // Then search backwards for a def, until we get to InsertBB.
404   while (MBB != InsertBB) {
405     if (CannotProvideElements(MBB, NumElements))
406       return false;
407     MBB = *MBB->pred_begin();
408   }
409 
410   LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n");
411   return true;
412 }
413 
414 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
415                                     ReachingDefAnalysis *RDA,
416                                     MachineLoopInfo *MLI) {
417   if (Revert)
418     return;
419 
420   if (!End->getOperand(1).isMBB())
421     report_fatal_error("Expected LoopEnd to target basic block");
422 
423   // TODO Maybe there's cases where the target doesn't have to be the header,
424   // but for now be safe and revert.
425   if (End->getOperand(1).getMBB() != ML->getHeader()) {
426     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
427     Revert = true;
428     return;
429   }
430 
431   // The WLS and LE instructions have 12-bits for the label offset. WLS
432   // requires a positive offset, while LE uses negative.
433   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
434       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
435     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
436     Revert = true;
437     return;
438   }
439 
440   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
441       (BBUtils->getOffsetOf(Start) >
442        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
443        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
444     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
445     Revert = true;
446     return;
447   }
448 
449   InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
450   if (!InsertPt) {
451     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
452     Revert = true;
453     return;
454   } else
455     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
456 
457   if (!IsTailPredicationLegal()) {
458     LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n");
459     return;
460   }
461 
462   assert(ML->getBlocks().size() == 1 &&
463          "Shouldn't be processing a loop with more than one block");
464   CannotTailPredicate = !ValidateTailPredicate(InsertPt, RDA, MLI);
465   LLVM_DEBUG(if (CannotTailPredicate)
466              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
467 }
468 
469 bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) {
470   // Only support a single vctp.
471   if (isVCTP(MI) && VCTP)
472     return false;
473 
474   // Start a new vpt block when we discover a vpt.
475   if (MI->getOpcode() == ARM::MVE_VPST) {
476     VPTBlocks.emplace_back(MI, CurrentPredicate);
477     CurrentBlock = &VPTBlocks.back();
478     return true;
479   }
480 
481   if (isVCTP(MI))
482     VCTP = MI;
483 
484   unsigned VPROpNum = MI->getNumOperands() - 1;
485   bool IsUse = false;
486   if (MI->getOperand(VPROpNum).isReg() &&
487       MI->getOperand(VPROpNum).getReg() == ARM::VPR &&
488       MI->getOperand(VPROpNum).isUse()) {
489     // If this instruction is predicated by VPR, it will be its last
490     // operand.  Also check that it's only 'Then' predicated.
491     if (!MI->getOperand(VPROpNum-1).isImm() ||
492         MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) {
493       LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: "
494                  << *MI);
495       return false;
496     }
497     CurrentBlock->addInst(MI, CurrentPredicate);
498     IsUse = true;
499   }
500 
501   bool IsDef = false;
502   for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) {
503     const MachineOperand &MO = MI->getOperand(i);
504     if (!MO.isReg() || MO.getReg() != ARM::VPR)
505       continue;
506 
507     if (MO.isDef()) {
508       CurrentPredicate.insert(MI);
509       IsDef = true;
510     } else {
511       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
512       return false;
513     }
514   }
515 
516   // If we find a vpr def that is not already predicated on the vctp, we've
517   // got disjoint predicates that may not be equivalent when we do the
518   // conversion.
519   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
520     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
521     return false;
522   }
523 
524   return true;
525 }
526 
527 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
528   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
529   if (!ST.hasLOB())
530     return false;
531 
532   MF = &mf;
533   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
534 
535   MLI = &getAnalysis<MachineLoopInfo>();
536   RDA = &getAnalysis<ReachingDefAnalysis>();
537   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
538   MRI = &MF->getRegInfo();
539   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
540   TRI = ST.getRegisterInfo();
541   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
542   BBUtils->computeAllBlockSizes();
543   BBUtils->adjustBBOffsetsAfter(&MF->front());
544 
545   bool Changed = false;
546   for (auto ML : *MLI) {
547     if (!ML->getParentLoop())
548       Changed |= ProcessLoop(ML);
549   }
550   Changed |= RevertNonLoops();
551   return Changed;
552 }
553 
554 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
555 
556   bool Changed = false;
557 
558   // Process inner loops first.
559   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
560     Changed |= ProcessLoop(*I);
561 
562   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
563              if (auto *Preheader = ML->getLoopPreheader())
564                dbgs() << " - " << Preheader->getName() << "\n";
565              else if (auto *Preheader = MLI->findLoopPreheader(ML))
566                dbgs() << " - " << Preheader->getName() << "\n";
567              for (auto *MBB : ML->getBlocks())
568                dbgs() << " - " << MBB->getName() << "\n";
569             );
570 
571   // Search the given block for a loop start instruction. If one isn't found,
572   // and there's only one predecessor block, search that one too.
573   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
574     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
575     for (auto &MI : *MBB) {
576       if (isLoopStart(MI))
577         return &MI;
578     }
579     if (MBB->pred_size() == 1)
580       return SearchForStart(*MBB->pred_begin());
581     return nullptr;
582   };
583 
584   LowOverheadLoop LoLoop(ML);
585   // Search the preheader for the start intrinsic.
586   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
587   // with potentially multiple set.loop.iterations, so we need to enable this.
588   if (auto *Preheader = ML->getLoopPreheader())
589     LoLoop.Start = SearchForStart(Preheader);
590   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
591     LoLoop.Start = SearchForStart(Preheader);
592   else
593     return false;
594 
595   // Find the low-overhead loop components and decide whether or not to fall
596   // back to a normal loop. Also look for a vctp instructions and decide
597   // whether we can convert that predicate using tail predication.
598   for (auto *MBB : reverse(ML->getBlocks())) {
599     for (auto &MI : *MBB) {
600       if (MI.getOpcode() == ARM::t2LoopDec)
601         LoLoop.Dec = &MI;
602       else if (MI.getOpcode() == ARM::t2LoopEnd)
603         LoLoop.End = &MI;
604       else if (isLoopStart(MI))
605         LoLoop.Start = &MI;
606       else if (MI.getDesc().isCall()) {
607         // TODO: Though the call will require LE to execute again, does this
608         // mean we should revert? Always executing LE hopefully should be
609         // faster than performing a sub,cmp,br or even subs,br.
610         LoLoop.Revert = true;
611         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
612       } else {
613         // Record VPR defs and build up their corresponding vpt blocks.
614         // Check we know how to tail predicate any mve instructions.
615         LoLoop.AnalyseMVEInst(&MI);
616       }
617 
618       // We need to ensure that LR is not used or defined inbetween LoopDec and
619       // LoopEnd.
620       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
621         continue;
622 
623       // If we find that LR has been written or read between LoopDec and
624       // LoopEnd, expect that the decremented value is being used else where.
625       // Because this value isn't actually going to be produced until the
626       // latch, by LE, we would need to generate a real sub. The value is also
627       // likely to be copied/reloaded for use of LoopEnd - in which in case
628       // we'd need to perform an add because it gets subtracted again by LE!
629       // The other option is to then generate the other form of LE which doesn't
630       // perform the sub.
631       for (auto &MO : MI.operands()) {
632         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
633             MO.getReg() == ARM::LR) {
634           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
635           LoLoop.Revert = true;
636           break;
637         }
638       }
639     }
640   }
641 
642   LLVM_DEBUG(LoLoop.dump());
643   if (!LoLoop.FoundAllComponents())
644     return false;
645 
646   LoLoop.CheckLegality(BBUtils.get(), RDA, MLI);
647   Expand(LoLoop);
648   return true;
649 }
650 
651 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
652 // beq that branches to the exit branch.
653 // TODO: We could also try to generate a cbz if the value in LR is also in
654 // another low register.
655 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
656   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
657   MachineBasicBlock *MBB = MI->getParent();
658   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
659                                     TII->get(ARM::t2CMPri));
660   MIB.add(MI->getOperand(0));
661   MIB.addImm(0);
662   MIB.addImm(ARMCC::AL);
663   MIB.addReg(ARM::NoRegister);
664 
665   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
666   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
667     ARM::tBcc : ARM::t2Bcc;
668 
669   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
670   MIB.add(MI->getOperand(1));   // branch target
671   MIB.addImm(ARMCC::EQ);        // condition code
672   MIB.addReg(ARM::CPSR);
673   MI->eraseFromParent();
674 }
675 
676 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
677                                         bool SetFlags) const {
678   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
679   MachineBasicBlock *MBB = MI->getParent();
680 
681   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
682   if (SetFlags &&
683       (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
684        !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
685       SetFlags = false;
686 
687   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
688                                     TII->get(ARM::t2SUBri));
689   MIB.addDef(ARM::LR);
690   MIB.add(MI->getOperand(1));
691   MIB.add(MI->getOperand(2));
692   MIB.addImm(ARMCC::AL);
693   MIB.addReg(0);
694 
695   if (SetFlags) {
696     MIB.addReg(ARM::CPSR);
697     MIB->getOperand(5).setIsDef(true);
698   } else
699     MIB.addReg(0);
700 
701   MI->eraseFromParent();
702   return SetFlags;
703 }
704 
705 // Generate a subs, or sub and cmp, and a branch instead of an LE.
706 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
707   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
708 
709   MachineBasicBlock *MBB = MI->getParent();
710   // Create cmp
711   if (!SkipCmp) {
712     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
713                                       TII->get(ARM::t2CMPri));
714     MIB.addReg(ARM::LR);
715     MIB.addImm(0);
716     MIB.addImm(ARMCC::AL);
717     MIB.addReg(ARM::NoRegister);
718   }
719 
720   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
721   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
722     ARM::tBcc : ARM::t2Bcc;
723 
724   // Create bne
725   MachineInstrBuilder MIB =
726     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
727   MIB.add(MI->getOperand(1));   // branch target
728   MIB.addImm(ARMCC::NE);        // condition code
729   MIB.addReg(ARM::CPSR);
730   MI->eraseFromParent();
731 }
732 
733 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
734   MachineInstr *InsertPt = LoLoop.InsertPt;
735   MachineInstr *Start = LoLoop.Start;
736   MachineBasicBlock *MBB = InsertPt->getParent();
737   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
738   unsigned Opc = LoLoop.getStartOpcode();
739   MachineOperand &Count = LoLoop.getCount();
740 
741   MachineInstrBuilder MIB =
742     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
743 
744   MIB.addDef(ARM::LR);
745   MIB.add(Count);
746   if (!IsDo)
747     MIB.add(Start->getOperand(1));
748 
749   // When using tail-predication, try to delete the dead code that was used to
750   // calculate the number of loop iterations.
751   if (LoLoop.IsTailPredicationLegal()) {
752     SmallVector<MachineInstr*, 4> Killed;
753     SmallVector<MachineInstr*, 4> Dead;
754     if (auto *Def = RDA->getReachingMIDef(Start,
755                                           Start->getOperand(0).getReg())) {
756       Killed.push_back(Def);
757 
758       while (!Killed.empty()) {
759         MachineInstr *Def = Killed.back();
760         Killed.pop_back();
761         Dead.push_back(Def);
762         for (auto &MO : Def->operands()) {
763           if (!MO.isReg() || !MO.isKill())
764             continue;
765 
766           MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg());
767           if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1)
768             Killed.push_back(Kill);
769         }
770       }
771       for (auto *MI : Dead)
772         MI->eraseFromParent();
773     }
774   }
775 
776   // If we're inserting at a mov lr, then remove it as it's redundant.
777   if (InsertPt != Start)
778     InsertPt->eraseFromParent();
779   Start->eraseFromParent();
780   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
781   return &*MIB;
782 }
783 
784 // Goal is to optimise and clean-up these loops:
785 //
786 //   vector.body:
787 //     renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg
788 //     renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4
789 //     ..
790 //     $lr = MVE_DLSTP_32 renamable $r3
791 //
792 // The SUB is the old update of the loop iteration count expression, which
793 // is no longer needed. This sub is removed when the element count, which is in
794 // r3 in this example, is defined by an instruction in the loop, and it has
795 // no uses.
796 //
797 void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) {
798   Register ElemCount = LoLoop.VCTP->getOperand(1).getReg();
799   MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back();
800 
801   LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n");
802 
803   if (LoLoop.ML->getNumBlocks() != 1) {
804     LLVM_DEBUG(dbgs() << "ARM Loops: single block loop expected\n");
805     return;
806   }
807 
808   LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing MO: ";
809              LoLoop.VCTP->getOperand(1).dump());
810 
811   // Find the definition we are interested in removing, if there is one.
812   MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount);
813   if (!Def)
814     return;
815 
816   // Bail if we define CPSR and it is not dead
817   if (!Def->registerDefIsDead(ARM::CPSR, TRI)) {
818     LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n");
819     return;
820   }
821 
822   // Bail if elemcount is used in exit blocks, i.e. if it is live-in.
823   if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) {
824     LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n");
825     return;
826   }
827 
828   // Bail if there are uses after this Def in the block.
829   SmallVector<MachineInstr*, 4> Uses;
830   RDA->getReachingLocalUses(Def, ElemCount, Uses);
831   if (Uses.size()) {
832     LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n");
833     return;
834   }
835 
836   Uses.clear();
837   RDA->getAllInstWithUseBefore(Def, ElemCount, Uses);
838 
839   // Remove Def if there are no uses, or if the only use is the VCTP
840   // instruction.
841   if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) {
842     LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: ";
843                Def->dump());
844     Def->eraseFromParent();
845   }
846 }
847 
848 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
849   auto RemovePredicate = [](MachineInstr *MI) {
850     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
851     unsigned OpNum = MI->getNumOperands() - 1;
852     assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then &&
853            "Expected Then predicate!");
854     MI->getOperand(OpNum-1).setImm(ARMVCC::None);
855     MI->getOperand(OpNum).setReg(0);
856   };
857 
858   // There are a few scenarios which we have to fix up:
859   // 1) A VPT block with is only predicated by the vctp and has no internal vpr
860   //    defs.
861   // 2) A VPT block which is only predicated by the vctp but has an internal
862   //    vpr def.
863   // 3) A VPT block which is predicated upon the vctp as well as another vpr
864   //    def.
865   // 4) A VPT block which is not predicated upon a vctp, but contains it and
866   //    all instructions within the block are predicated upon in.
867 
868   for (auto &Block : LoLoop.getVPTBlocks()) {
869     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
870     if (Block.HasNonUniformPredicate()) {
871       PredicatedMI *Divergent = Block.getDivergent();
872       if (isVCTP(Divergent->MI)) {
873         // The vctp will be removed, so the size of the vpt block needs to be
874         // modified.
875         uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
876         Block.getVPST()->getOperand(0).setImm(Size);
877         LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
878       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
879         // The VPT block has a non-uniform predicate but it's entry is guarded
880         // only by a vctp, which means we:
881         // - Need to remove the original vpst.
882         // - Then need to unpredicate any following instructions, until
883         //   we come across the divergent vpr def.
884         // - Insert a new vpst to predicate the instruction(s) that following
885         //   the divergent vpr def.
886         // TODO: We could be producing more VPT blocks than necessary and could
887         // fold the newly created one into a proceeding one.
888         for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
889              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
890           RemovePredicate(&*I);
891 
892         unsigned Size = 0;
893         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
894         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
895         MachineInstr *InsertAt = nullptr;
896         while (I != E) {
897           InsertAt = &*I;
898           ++Size;
899           ++I;
900         }
901         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
902                                           InsertAt->getDebugLoc(),
903                                           TII->get(ARM::MVE_VPST));
904         MIB.addImm(getARMVPTBlockMask(Size));
905         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
906         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
907         Block.getVPST()->eraseFromParent();
908       }
909     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
910       // A vpt block which is only predicated upon vctp and has no internal vpr
911       // defs:
912       // - Remove vpst.
913       // - Unpredicate the remaining instructions.
914       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
915       Block.getVPST()->eraseFromParent();
916       for (auto &PredMI : Insts)
917         RemovePredicate(PredMI.MI);
918     }
919   }
920 
921   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
922   LoLoop.VCTP->eraseFromParent();
923 }
924 
925 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
926 
927   // Combine the LoopDec and LoopEnd instructions into LE(TP).
928   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
929     MachineInstr *End = LoLoop.End;
930     MachineBasicBlock *MBB = End->getParent();
931     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
932       ARM::MVE_LETP : ARM::t2LEUpdate;
933     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
934                                       TII->get(Opc));
935     MIB.addDef(ARM::LR);
936     MIB.add(End->getOperand(0));
937     MIB.add(End->getOperand(1));
938     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
939 
940     LoLoop.End->eraseFromParent();
941     LoLoop.Dec->eraseFromParent();
942     return &*MIB;
943   };
944 
945   // TODO: We should be able to automatically remove these branches before we
946   // get here - probably by teaching analyzeBranch about the pseudo
947   // instructions.
948   // If there is an unconditional branch, after I, that just branches to the
949   // next block, remove it.
950   auto RemoveDeadBranch = [](MachineInstr *I) {
951     MachineBasicBlock *BB = I->getParent();
952     MachineInstr *Terminator = &BB->instr_back();
953     if (Terminator->isUnconditionalBranch() && I != Terminator) {
954       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
955       if (BB->isLayoutSuccessor(Succ)) {
956         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
957         Terminator->eraseFromParent();
958       }
959     }
960   };
961 
962   if (LoLoop.Revert) {
963     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
964       RevertWhile(LoLoop.Start);
965     else
966       LoLoop.Start->eraseFromParent();
967     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
968     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
969   } else {
970     LoLoop.Start = ExpandLoopStart(LoLoop);
971     RemoveDeadBranch(LoLoop.Start);
972     LoLoop.End = ExpandLoopEnd(LoLoop);
973     RemoveDeadBranch(LoLoop.End);
974     if (LoLoop.IsTailPredicationLegal()) {
975       RemoveLoopUpdate(LoLoop);
976       ConvertVPTBlocks(LoLoop);
977     }
978   }
979 }
980 
981 bool ARMLowOverheadLoops::RevertNonLoops() {
982   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
983   bool Changed = false;
984 
985   for (auto &MBB : *MF) {
986     SmallVector<MachineInstr*, 4> Starts;
987     SmallVector<MachineInstr*, 4> Decs;
988     SmallVector<MachineInstr*, 4> Ends;
989 
990     for (auto &I : MBB) {
991       if (isLoopStart(I))
992         Starts.push_back(&I);
993       else if (I.getOpcode() == ARM::t2LoopDec)
994         Decs.push_back(&I);
995       else if (I.getOpcode() == ARM::t2LoopEnd)
996         Ends.push_back(&I);
997     }
998 
999     if (Starts.empty() && Decs.empty() && Ends.empty())
1000       continue;
1001 
1002     Changed = true;
1003 
1004     for (auto *Start : Starts) {
1005       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1006         RevertWhile(Start);
1007       else
1008         Start->eraseFromParent();
1009     }
1010     for (auto *Dec : Decs)
1011       RevertLoopDec(Dec);
1012 
1013     for (auto *End : Ends)
1014       RevertLoopEnd(End);
1015   }
1016   return Changed;
1017 }
1018 
1019 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1020   return new ARMLowOverheadLoops();
1021 }
1022