xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision 049f9672d8566f0d0a115f11e2a53018ea502b10)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 //===----------------------------------------------------------------------===//
19 
20 #include "ARM.h"
21 #include "ARMBaseInstrInfo.h"
22 #include "ARMBaseRegisterInfo.h"
23 #include "ARMBasicBlockInfo.h"
24 #include "ARMSubtarget.h"
25 #include "llvm/CodeGen/MachineFunctionPass.h"
26 #include "llvm/CodeGen/MachineLoopInfo.h"
27 #include "llvm/CodeGen/MachineLoopUtils.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/ReachingDefAnalysis.h"
31 #include "llvm/MC/MCInstrDesc.h"
32 
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "arm-low-overhead-loops"
36 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
37 
38 namespace {
39 
40   struct LowOverheadLoop {
41 
42     MachineLoop *ML = nullptr;
43     MachineFunction *MF = nullptr;
44     MachineInstr *InsertPt = nullptr;
45     MachineInstr *Start = nullptr;
46     MachineInstr *Dec = nullptr;
47     MachineInstr *End = nullptr;
48     MachineInstr *VCTP = nullptr;
49     SmallVector<MachineInstr*, 4> VPTUsers;
50     bool Revert = false;
51     bool FoundOneVCTP = false;
52     bool CannotTailPredicate = false;
53 
54     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
55       MF = ML->getHeader()->getParent();
56     }
57 
58     // For now, only support one vctp instruction. If we find multiple then
59     // we shouldn't perform tail predication.
60     void addVCTP(MachineInstr *MI) {
61       if (!VCTP) {
62         VCTP = MI;
63         FoundOneVCTP = true;
64       } else
65         FoundOneVCTP = false;
66     }
67 
68     // Check that nothing else is writing to VPR and record any insts
69     // reading the VPR.
70     void ScanForVPR(MachineInstr *MI) {
71       for (auto &MO : MI->operands()) {
72         if (!MO.isReg() || MO.getReg() != ARM::VPR)
73           continue;
74         if (MO.isUse())
75           VPTUsers.push_back(MI);
76         if (MO.isDef()) {
77           CannotTailPredicate = true;
78           break;
79         }
80       }
81     }
82 
83     // If this is an MVE instruction, check that we know how to use tail
84     // predication with it.
85     void CheckTPValidity(MachineInstr *MI) {
86       if (CannotTailPredicate)
87         return;
88 
89       const MCInstrDesc &MCID = MI->getDesc();
90       uint64_t Flags = MCID.TSFlags;
91       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
92         return;
93 
94       if ((Flags & ARMII::ValidForTailPredication) == 0) {
95         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
96         CannotTailPredicate = true;
97       }
98     }
99 
100     bool IsTailPredicationLegal() const {
101       // For now, let's keep things really simple and only support a single
102       // block for tail predication.
103       return !Revert && FoundAllComponents() && FoundOneVCTP &&
104              !CannotTailPredicate && ML->getNumBlocks() == 1;
105     }
106 
107     // Is it safe to define LR with DLS/WLS?
108     // LR can be defined if it is the operand to start, because it's the same
109     // value, or if it's going to be equivalent to the operand to Start.
110     MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
111 
112     // Check the branch targets are within range and we satisfy our
113     // restrictions.
114     void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA,
115                        MachineLoopInfo *MLI);
116 
117     bool FoundAllComponents() const {
118       return Start && Dec && End;
119     }
120 
121     // Return the loop iteration count, or the number of elements if we're tail
122     // predicating.
123     MachineOperand &getCount() {
124       return IsTailPredicationLegal() ?
125         VCTP->getOperand(1) : Start->getOperand(0);
126     }
127 
128     unsigned getStartOpcode() const {
129       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
130       if (!IsTailPredicationLegal())
131         return IsDo ? ARM::t2DLS : ARM::t2WLS;
132 
133       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
134     }
135 
136     void dump() const {
137       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
138       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
139       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
140       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
141       if (!FoundAllComponents())
142         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
143       else if (!(Start && Dec && End))
144         dbgs() << "ARM Loops: Failed to find all loop components.\n";
145     }
146   };
147 
148   class ARMLowOverheadLoops : public MachineFunctionPass {
149     MachineFunction           *MF = nullptr;
150     MachineLoopInfo           *MLI = nullptr;
151     ReachingDefAnalysis       *RDA = nullptr;
152     const ARMBaseInstrInfo    *TII = nullptr;
153     MachineRegisterInfo       *MRI = nullptr;
154     const TargetRegisterInfo  *TRI = nullptr;
155     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
156 
157   public:
158     static char ID;
159 
160     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
161 
162     void getAnalysisUsage(AnalysisUsage &AU) const override {
163       AU.setPreservesCFG();
164       AU.addRequired<MachineLoopInfo>();
165       AU.addRequired<ReachingDefAnalysis>();
166       MachineFunctionPass::getAnalysisUsage(AU);
167     }
168 
169     bool runOnMachineFunction(MachineFunction &MF) override;
170 
171     MachineFunctionProperties getRequiredProperties() const override {
172       return MachineFunctionProperties().set(
173           MachineFunctionProperties::Property::NoVRegs).set(
174           MachineFunctionProperties::Property::TracksLiveness);
175     }
176 
177     StringRef getPassName() const override {
178       return ARM_LOW_OVERHEAD_LOOPS_NAME;
179     }
180 
181   private:
182     bool ProcessLoop(MachineLoop *ML);
183 
184     bool RevertNonLoops();
185 
186     void RevertWhile(MachineInstr *MI) const;
187 
188     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
189 
190     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
191 
192     void RemoveLoopUpdate(LowOverheadLoop &LoLoop);
193 
194     void RemoveVPTBlocks(LowOverheadLoop &LoLoop);
195 
196     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
197 
198     void Expand(LowOverheadLoop &LoLoop);
199 
200   };
201 }
202 
203 char ARMLowOverheadLoops::ID = 0;
204 
205 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
206                 false, false)
207 
208 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
209   // We can define LR because LR already contains the same value.
210   if (Start->getOperand(0).getReg() == ARM::LR)
211     return Start;
212 
213   unsigned CountReg = Start->getOperand(0).getReg();
214   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
215     return MI->getOpcode() == ARM::tMOVr &&
216            MI->getOperand(0).getReg() == ARM::LR &&
217            MI->getOperand(1).getReg() == CountReg &&
218            MI->getOperand(2).getImm() == ARMCC::AL;
219    };
220 
221   MachineBasicBlock *MBB = Start->getParent();
222 
223   // Find an insertion point:
224   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
225   //   to Count before Start, we can insert at that mov.
226   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
227   //   to Count after Start, we can insert at that mov.
228   if (auto *LRDef = RDA->getReachingMIDef(&MBB->back(), ARM::LR)) {
229     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
230       return LRDef;
231   }
232 
233   // We've found no suitable LR def and Start doesn't use LR directly. Can we
234   // just define LR anyway?
235   if (!RDA->isRegUsedAfter(Start, ARM::LR))
236     return Start;
237 
238   return nullptr;
239 }
240 
241 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
242                                     ReachingDefAnalysis *RDA,
243                                     MachineLoopInfo *MLI) {
244   if (Revert)
245     return;
246 
247   if (!End->getOperand(1).isMBB())
248     report_fatal_error("Expected LoopEnd to target basic block");
249 
250   // TODO Maybe there's cases where the target doesn't have to be the header,
251   // but for now be safe and revert.
252   if (End->getOperand(1).getMBB() != ML->getHeader()) {
253     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
254     Revert = true;
255     return;
256   }
257 
258   // The WLS and LE instructions have 12-bits for the label offset. WLS
259   // requires a positive offset, while LE uses negative.
260   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
261       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
262     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
263     Revert = true;
264     return;
265   }
266 
267   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
268       (BBUtils->getOffsetOf(Start) >
269        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
270        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
271     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
272     Revert = true;
273     return;
274   }
275 
276   InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
277   if (!InsertPt) {
278     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
279     Revert = true;
280     return;
281   } else
282     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
283 
284   // For tail predication, we need to provide the number of elements, instead
285   // of the iteration count, to the loop start instruction. The number of
286   // elements is provided to the vctp instruction, so we need to check that
287   // we can use this register at InsertPt.
288   if (!IsTailPredicationLegal())
289     return;
290 
291   Register NumElements = VCTP->getOperand(1).getReg();
292 
293   // If the register is defined within loop, then we can't perform TP.
294   // TODO: Check whether this is just a mov of a register that would be
295   // available.
296   if (RDA->getReachingDef(VCTP, NumElements) >= 0) {
297     CannotTailPredicate = true;
298     return;
299   }
300 
301   // We can't perform TP if the register does not hold the same value at
302   // InsertPt as the liveout value.
303   MachineBasicBlock *InsertBB = InsertPt->getParent();
304   if  (!RDA->hasSameReachingDef(InsertPt, &InsertBB->back(),
305                                 NumElements)) {
306     CannotTailPredicate = true;
307     return;
308   }
309 
310   // Especially in the case of while loops, InsertBB may not be the
311   // preheader, so we need to check that the register isn't redefined
312   // before entering the loop.
313   auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB,
314                                       Register NumElements) {
315     // NumElements is redefined in this block.
316     if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0)
317       return true;
318 
319     // Don't continue searching up through multiple predecessors.
320     if (MBB->pred_size() > 1)
321       return true;
322 
323     return false;
324   };
325 
326   // First, find the block that looks like the preheader.
327   MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true);
328   if (!MBB) {
329     CannotTailPredicate = true;
330     return;
331   }
332 
333   // Then search backwards for a def, until we get to InsertBB.
334   while (MBB != InsertBB) {
335     CannotTailPredicate = CannotProvideElements(MBB, NumElements);
336     if (CannotTailPredicate)
337       return;
338     MBB = *MBB->pred_begin();
339   }
340 
341   LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication to convert:\n";
342                for (auto *MI : VPTUsers)
343                  dbgs() << " - " << *MI;);
344 }
345 
346 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
347   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
348   if (!ST.hasLOB())
349     return false;
350 
351   MF = &mf;
352   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
353 
354   MLI = &getAnalysis<MachineLoopInfo>();
355   RDA = &getAnalysis<ReachingDefAnalysis>();
356   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
357   MRI = &MF->getRegInfo();
358   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
359   TRI = ST.getRegisterInfo();
360   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
361   BBUtils->computeAllBlockSizes();
362   BBUtils->adjustBBOffsetsAfter(&MF->front());
363 
364   bool Changed = false;
365   for (auto ML : *MLI) {
366     if (!ML->getParentLoop())
367       Changed |= ProcessLoop(ML);
368   }
369   Changed |= RevertNonLoops();
370   return Changed;
371 }
372 
373 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
374 
375   bool Changed = false;
376 
377   // Process inner loops first.
378   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
379     Changed |= ProcessLoop(*I);
380 
381   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
382              if (auto *Preheader = ML->getLoopPreheader())
383                dbgs() << " - " << Preheader->getName() << "\n";
384              else if (auto *Preheader = MLI->findLoopPreheader(ML))
385                dbgs() << " - " << Preheader->getName() << "\n";
386              for (auto *MBB : ML->getBlocks())
387                dbgs() << " - " << MBB->getName() << "\n";
388             );
389 
390   // Search the given block for a loop start instruction. If one isn't found,
391   // and there's only one predecessor block, search that one too.
392   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
393     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
394     for (auto &MI : *MBB) {
395       if (isLoopStart(MI))
396         return &MI;
397     }
398     if (MBB->pred_size() == 1)
399       return SearchForStart(*MBB->pred_begin());
400     return nullptr;
401   };
402 
403   LowOverheadLoop LoLoop(ML);
404   // Search the preheader for the start intrinsic.
405   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
406   // with potentially multiple set.loop.iterations, so we need to enable this.
407   if (auto *Preheader = ML->getLoopPreheader())
408     LoLoop.Start = SearchForStart(Preheader);
409   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
410     LoLoop.Start = SearchForStart(Preheader);
411   else
412     return false;
413 
414   // Find the low-overhead loop components and decide whether or not to fall
415   // back to a normal loop. Also look for a vctp instructions and decide
416   // whether we can convert that predicate using tail predication.
417   for (auto *MBB : reverse(ML->getBlocks())) {
418     for (auto &MI : *MBB) {
419       if (MI.getOpcode() == ARM::t2LoopDec)
420         LoLoop.Dec = &MI;
421       else if (MI.getOpcode() == ARM::t2LoopEnd)
422         LoLoop.End = &MI;
423       else if (isLoopStart(MI))
424         LoLoop.Start = &MI;
425       else if (isVCTP(&MI))
426         LoLoop.addVCTP(&MI);
427       else if (MI.getDesc().isCall()) {
428         // TODO: Though the call will require LE to execute again, does this
429         // mean we should revert? Always executing LE hopefully should be
430         // faster than performing a sub,cmp,br or even subs,br.
431         LoLoop.Revert = true;
432         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
433       } else {
434         // Once we've found a vctp, record the users of vpr and check there's
435         // no more vpr defs.
436         if (LoLoop.FoundOneVCTP)
437           LoLoop.ScanForVPR(&MI);
438         // Check we know how to tail predicate any mve instructions.
439         LoLoop.CheckTPValidity(&MI);
440       }
441 
442       // We need to ensure that LR is not used or defined inbetween LoopDec and
443       // LoopEnd.
444       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
445         continue;
446 
447       // If we find that LR has been written or read between LoopDec and
448       // LoopEnd, expect that the decremented value is being used else where.
449       // Because this value isn't actually going to be produced until the
450       // latch, by LE, we would need to generate a real sub. The value is also
451       // likely to be copied/reloaded for use of LoopEnd - in which in case
452       // we'd need to perform an add because it gets subtracted again by LE!
453       // The other option is to then generate the other form of LE which doesn't
454       // perform the sub.
455       for (auto &MO : MI.operands()) {
456         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
457             MO.getReg() == ARM::LR) {
458           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
459           LoLoop.Revert = true;
460           break;
461         }
462       }
463     }
464   }
465 
466   LLVM_DEBUG(LoLoop.dump());
467   if (!LoLoop.FoundAllComponents())
468     return false;
469 
470   LoLoop.CheckLegality(BBUtils.get(), RDA, MLI);
471   Expand(LoLoop);
472   return true;
473 }
474 
475 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
476 // beq that branches to the exit branch.
477 // TODO: We could also try to generate a cbz if the value in LR is also in
478 // another low register.
479 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
480   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
481   MachineBasicBlock *MBB = MI->getParent();
482   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
483                                     TII->get(ARM::t2CMPri));
484   MIB.add(MI->getOperand(0));
485   MIB.addImm(0);
486   MIB.addImm(ARMCC::AL);
487   MIB.addReg(ARM::NoRegister);
488 
489   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
490   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
491     ARM::tBcc : ARM::t2Bcc;
492 
493   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
494   MIB.add(MI->getOperand(1));   // branch target
495   MIB.addImm(ARMCC::EQ);        // condition code
496   MIB.addReg(ARM::CPSR);
497   MI->eraseFromParent();
498 }
499 
500 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
501                                         bool SetFlags) const {
502   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
503   MachineBasicBlock *MBB = MI->getParent();
504 
505   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
506   if (SetFlags &&
507       (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
508        !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
509       SetFlags = false;
510 
511   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
512                                     TII->get(ARM::t2SUBri));
513   MIB.addDef(ARM::LR);
514   MIB.add(MI->getOperand(1));
515   MIB.add(MI->getOperand(2));
516   MIB.addImm(ARMCC::AL);
517   MIB.addReg(0);
518 
519   if (SetFlags) {
520     MIB.addReg(ARM::CPSR);
521     MIB->getOperand(5).setIsDef(true);
522   } else
523     MIB.addReg(0);
524 
525   MI->eraseFromParent();
526   return SetFlags;
527 }
528 
529 // Generate a subs, or sub and cmp, and a branch instead of an LE.
530 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
531   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
532 
533   MachineBasicBlock *MBB = MI->getParent();
534   // Create cmp
535   if (!SkipCmp) {
536     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
537                                       TII->get(ARM::t2CMPri));
538     MIB.addReg(ARM::LR);
539     MIB.addImm(0);
540     MIB.addImm(ARMCC::AL);
541     MIB.addReg(ARM::NoRegister);
542   }
543 
544   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
545   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
546     ARM::tBcc : ARM::t2Bcc;
547 
548   // Create bne
549   MachineInstrBuilder MIB =
550     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
551   MIB.add(MI->getOperand(1));   // branch target
552   MIB.addImm(ARMCC::NE);        // condition code
553   MIB.addReg(ARM::CPSR);
554   MI->eraseFromParent();
555 }
556 
557 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
558   MachineInstr *InsertPt = LoLoop.InsertPt;
559   MachineInstr *Start = LoLoop.Start;
560   MachineBasicBlock *MBB = InsertPt->getParent();
561   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
562   unsigned Opc = LoLoop.getStartOpcode();
563   MachineOperand &Count = LoLoop.getCount();
564 
565   MachineInstrBuilder MIB =
566     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
567 
568   MIB.addDef(ARM::LR);
569   MIB.add(Count);
570   if (!IsDo)
571     MIB.add(Start->getOperand(1));
572 
573   // When using tail-predication, try to delete the dead code that was used to
574   // calculate the number of loop iterations.
575   if (LoLoop.IsTailPredicationLegal()) {
576     SmallVector<MachineInstr*, 4> Killed;
577     SmallVector<MachineInstr*, 4> Dead;
578     if (auto *Def = RDA->getReachingMIDef(Start,
579                                           Start->getOperand(0).getReg())) {
580       Killed.push_back(Def);
581 
582       while (!Killed.empty()) {
583         MachineInstr *Def = Killed.back();
584         Killed.pop_back();
585         Dead.push_back(Def);
586         for (auto &MO : Def->operands()) {
587           if (!MO.isReg() || !MO.isKill())
588             continue;
589 
590           MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg());
591           if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1)
592             Killed.push_back(Kill);
593         }
594       }
595       for (auto *MI : Dead)
596         MI->eraseFromParent();
597     }
598   }
599 
600   // If we're inserting at a mov lr, then remove it as it's redundant.
601   if (InsertPt != Start)
602     InsertPt->eraseFromParent();
603   Start->eraseFromParent();
604   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
605   return &*MIB;
606 }
607 
608 // Goal is to optimise and clean-up these loops:
609 //
610 //   vector.body:
611 //     renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg
612 //     renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4
613 //     ..
614 //     $lr = MVE_DLSTP_32 renamable $r3
615 //
616 // The SUB is the old update of the loop iteration count expression, which
617 // is no longer needed. This sub is removed when the element count, which is in
618 // r3 in this example, is defined by an instruction in the loop, and it has
619 // no uses.
620 //
621 void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) {
622   Register ElemCount = LoLoop.VCTP->getOperand(1).getReg();
623   MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back();
624 
625   LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n");
626 
627   if (LoLoop.ML->getNumBlocks() != 1) {
628     LLVM_DEBUG(dbgs() << "ARM Loops: single block loop expected\n");
629     return;
630   }
631 
632   LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing MO: ";
633              LoLoop.VCTP->getOperand(1).dump());
634 
635   // Find the definition we are interested in removing, if there is one.
636   MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount);
637   if (!Def)
638     return;
639 
640   // Bail if we define CPSR and it is not dead
641   if (!Def->registerDefIsDead(ARM::CPSR, TRI)) {
642     LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n");
643     return;
644   }
645 
646   // Bail if elemcount is used in exit blocks, i.e. if it is live-in.
647   if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) {
648     LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n");
649     return;
650   }
651 
652   // Bail if there are uses after this Def in the block.
653   SmallVector<MachineInstr*, 4> Uses;
654   RDA->getReachingLocalUses(Def, ElemCount, Uses);
655   if (Uses.size()) {
656     LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n");
657     return;
658   }
659 
660   Uses.clear();
661   RDA->getAllInstWithUseBefore(Def, ElemCount, Uses);
662 
663   // Remove Def if there are no uses, or if the only use is the VCTP
664   // instruction.
665   if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) {
666     LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: ";
667                Def->dump());
668     Def->eraseFromParent();
669   }
670 }
671 
672 void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) {
673   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
674   LoLoop.VCTP->eraseFromParent();
675 
676   for (auto *MI : LoLoop.VPTUsers) {
677     if (MI->getOpcode() == ARM::MVE_VPST) {
678       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI);
679       MI->eraseFromParent();
680     } else {
681       unsigned OpNum = MI->getNumOperands() - 1;
682       assert((MI->getOperand(OpNum).isReg() &&
683               MI->getOperand(OpNum).getReg() == ARM::VPR) &&
684              "Expected VPR");
685       assert((MI->getOperand(OpNum-1).isImm() &&
686               MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) &&
687              "Expected Then predicate");
688       MI->getOperand(OpNum-1).setImm(ARMVCC::None);
689       MI->getOperand(OpNum).setReg(0);
690       LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI);
691     }
692   }
693 }
694 
695 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
696 
697   // Combine the LoopDec and LoopEnd instructions into LE(TP).
698   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
699     MachineInstr *End = LoLoop.End;
700     MachineBasicBlock *MBB = End->getParent();
701     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
702       ARM::MVE_LETP : ARM::t2LEUpdate;
703     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
704                                       TII->get(Opc));
705     MIB.addDef(ARM::LR);
706     MIB.add(End->getOperand(0));
707     MIB.add(End->getOperand(1));
708     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
709 
710     LoLoop.End->eraseFromParent();
711     LoLoop.Dec->eraseFromParent();
712     return &*MIB;
713   };
714 
715   // TODO: We should be able to automatically remove these branches before we
716   // get here - probably by teaching analyzeBranch about the pseudo
717   // instructions.
718   // If there is an unconditional branch, after I, that just branches to the
719   // next block, remove it.
720   auto RemoveDeadBranch = [](MachineInstr *I) {
721     MachineBasicBlock *BB = I->getParent();
722     MachineInstr *Terminator = &BB->instr_back();
723     if (Terminator->isUnconditionalBranch() && I != Terminator) {
724       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
725       if (BB->isLayoutSuccessor(Succ)) {
726         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
727         Terminator->eraseFromParent();
728       }
729     }
730   };
731 
732   if (LoLoop.Revert) {
733     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
734       RevertWhile(LoLoop.Start);
735     else
736       LoLoop.Start->eraseFromParent();
737     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
738     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
739   } else {
740     LoLoop.Start = ExpandLoopStart(LoLoop);
741     RemoveDeadBranch(LoLoop.Start);
742     LoLoop.End = ExpandLoopEnd(LoLoop);
743     RemoveDeadBranch(LoLoop.End);
744     if (LoLoop.IsTailPredicationLegal()) {
745       RemoveLoopUpdate(LoLoop);
746       RemoveVPTBlocks(LoLoop);
747     }
748   }
749 }
750 
751 bool ARMLowOverheadLoops::RevertNonLoops() {
752   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
753   bool Changed = false;
754 
755   for (auto &MBB : *MF) {
756     SmallVector<MachineInstr*, 4> Starts;
757     SmallVector<MachineInstr*, 4> Decs;
758     SmallVector<MachineInstr*, 4> Ends;
759 
760     for (auto &I : MBB) {
761       if (isLoopStart(I))
762         Starts.push_back(&I);
763       else if (I.getOpcode() == ARM::t2LoopDec)
764         Decs.push_back(&I);
765       else if (I.getOpcode() == ARM::t2LoopEnd)
766         Ends.push_back(&I);
767     }
768 
769     if (Starts.empty() && Decs.empty() && Ends.empty())
770       continue;
771 
772     Changed = true;
773 
774     for (auto *Start : Starts) {
775       if (Start->getOpcode() == ARM::t2WhileLoopStart)
776         RevertWhile(Start);
777       else
778         Start->eraseFromParent();
779     }
780     for (auto *Dec : Decs)
781       RevertLoopDec(Dec);
782 
783     for (auto *End : Ends)
784       RevertLoopEnd(End);
785   }
786   return Changed;
787 }
788 
789 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
790   return new ARMLowOverheadLoops();
791 }
792