xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision cced971fd3d6713ec4989990e1b2f42c8539f0f3)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
14 ///   preheader's only predecessor.
15 /// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 //===----------------------------------------------------------------------===//
19 
20 #include "ARM.h"
21 #include "ARMBaseInstrInfo.h"
22 #include "ARMBaseRegisterInfo.h"
23 #include "ARMBasicBlockInfo.h"
24 #include "ARMSubtarget.h"
25 #include "llvm/CodeGen/MachineFunctionPass.h"
26 #include "llvm/CodeGen/MachineLoopInfo.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/CodeGen/Passes.h"
29 #include "llvm/CodeGen/ReachingDefAnalysis.h"
30 #include "llvm/MC/MCInstrDesc.h"
31 
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "arm-low-overhead-loops"
35 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
36 
37 namespace {
38 
39   struct LowOverheadLoop {
40 
41     MachineLoop *ML = nullptr;
42     MachineFunction *MF = nullptr;
43     MachineInstr *InsertPt = nullptr;
44     MachineInstr *Start = nullptr;
45     MachineInstr *Dec = nullptr;
46     MachineInstr *End = nullptr;
47     MachineInstr *VCTP = nullptr;
48     SmallVector<MachineInstr*, 4> VPTUsers;
49     bool Revert = false;
50     bool FoundOneVCTP = false;
51     bool CannotTailPredicate = false;
52 
53     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
54       MF = ML->getHeader()->getParent();
55     }
56 
57     // For now, only support one vctp instruction. If we find multiple then
58     // we shouldn't perform tail predication.
59     void addVCTP(MachineInstr *MI) {
60       if (!VCTP) {
61         VCTP = MI;
62         FoundOneVCTP = true;
63       } else
64         FoundOneVCTP = false;
65     }
66 
67     // Check that nothing else is writing to VPR and record any insts
68     // reading the VPR.
69     void ScanForVPR(MachineInstr *MI) {
70       for (auto &MO : MI->operands()) {
71         if (!MO.isReg() || MO.getReg() != ARM::VPR)
72           continue;
73         if (MO.isUse())
74           VPTUsers.push_back(MI);
75         if (MO.isDef()) {
76           CannotTailPredicate = true;
77           break;
78         }
79       }
80     }
81 
82     // If this is an MVE instruction, check that we know how to use tail
83     // predication with it.
84     void CheckTPValidity(MachineInstr *MI) {
85       if (CannotTailPredicate)
86         return;
87 
88       const MCInstrDesc &MCID = MI->getDesc();
89       uint64_t Flags = MCID.TSFlags;
90       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
91         return;
92 
93       if ((Flags & ARMII::ValidForTailPredication) == 0) {
94         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
95         CannotTailPredicate = true;
96       }
97     }
98 
99     bool IsTailPredicationLegal() const {
100       // For now, let's keep things really simple and only support a single
101       // block for tail predication.
102       return !Revert && FoundAllComponents() && FoundOneVCTP &&
103              !CannotTailPredicate && ML->getNumBlocks() == 1;
104     }
105 
106     // Is it safe to define LR with DLS/WLS?
107     // LR can be defined if it is the operand to start, because it's the same
108     // value, or if it's going to be equivalent to the operand to Start.
109     MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
110 
111     // Check the branch targets are within range and we satisfy our
112     // restrictions.
113     void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA);
114 
115     bool FoundAllComponents() const {
116       return Start && Dec && End;
117     }
118 
119     void dump() const {
120       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
121       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
122       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
123       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
124       if (!FoundAllComponents())
125         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
126       else if (!(Start && Dec && End))
127         dbgs() << "ARM Loops: Failed to find all loop components.\n";
128     }
129   };
130 
131   class ARMLowOverheadLoops : public MachineFunctionPass {
132     MachineFunction           *MF = nullptr;
133     ReachingDefAnalysis       *RDA = nullptr;
134     const ARMBaseInstrInfo    *TII = nullptr;
135     MachineRegisterInfo       *MRI = nullptr;
136     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
137 
138   public:
139     static char ID;
140 
141     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
142 
143     void getAnalysisUsage(AnalysisUsage &AU) const override {
144       AU.setPreservesCFG();
145       AU.addRequired<MachineLoopInfo>();
146       AU.addRequired<ReachingDefAnalysis>();
147       MachineFunctionPass::getAnalysisUsage(AU);
148     }
149 
150     bool runOnMachineFunction(MachineFunction &MF) override;
151 
152     MachineFunctionProperties getRequiredProperties() const override {
153       return MachineFunctionProperties().set(
154           MachineFunctionProperties::Property::NoVRegs).set(
155           MachineFunctionProperties::Property::TracksLiveness);
156     }
157 
158     StringRef getPassName() const override {
159       return ARM_LOW_OVERHEAD_LOOPS_NAME;
160     }
161 
162   private:
163     bool ProcessLoop(MachineLoop *ML);
164 
165     bool RevertNonLoops();
166 
167     void RevertWhile(MachineInstr *MI) const;
168 
169     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
170 
171     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
172 
173     void RemoveVPTBlocks(LowOverheadLoop &LoLoop);
174 
175     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
176 
177     void Expand(LowOverheadLoop &LoLoop);
178 
179   };
180 }
181 
// Storage for the pass identifier that the constructor hands to
// MachineFunctionPass.
char ARMLowOverheadLoops::ID = 0;

// Register the pass, using DEBUG_TYPE as its argument name and
// ARM_LOW_OVERHEAD_LOOPS_NAME as its description.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
186 
187 static bool IsLoopStart(MachineInstr &MI) {
188   return MI.getOpcode() == ARM::t2DoLoopStart ||
189          MI.getOpcode() == ARM::t2WhileLoopStart;
190 }
191 
192 static bool IsVCTP(MachineInstr *MI) {
193   switch (MI->getOpcode()) {
194   default:
195     break;
196   case ARM::MVE_VCTP8:
197   case ARM::MVE_VCTP16:
198   case ARM::MVE_VCTP32:
199   case ARM::MVE_VCTP64:
200     return true;
201   }
202   return false;
203 }
204 
205 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
206   // We can define LR because LR already contains the same value.
207   if (Start->getOperand(0).getReg() == ARM::LR)
208     return Start;
209 
210   unsigned CountReg = Start->getOperand(0).getReg();
211   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
212     return MI->getOpcode() == ARM::tMOVr &&
213            MI->getOperand(0).getReg() == ARM::LR &&
214            MI->getOperand(1).getReg() == CountReg &&
215            MI->getOperand(2).getImm() == ARMCC::AL;
216    };
217 
218   MachineBasicBlock *MBB = Start->getParent();
219 
220   // Find an insertion point:
221   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
222   //   to Count before Start, we can insert at that mov.
223   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
224   //   to Count after Start, we can insert at that mov.
225   if (auto *LRDef = RDA->getReachingMIDef(&MBB->back(), ARM::LR)) {
226     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
227       return LRDef;
228   }
229 
230   // We've found no suitable LR def and Start doesn't use LR directly. Can we
231   // just define LR anyway?
232   if (!RDA->isRegUsedAfter(Start, ARM::LR))
233     return Start;
234 
235   return nullptr;
236 }
237 
238 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
239                                     ReachingDefAnalysis *RDA) {
240   if (Revert)
241     return;
242 
243   if (!End->getOperand(1).isMBB())
244     report_fatal_error("Expected LoopEnd to target basic block");
245 
246   // TODO Maybe there's cases where the target doesn't have to be the header,
247   // but for now be safe and revert.
248   if (End->getOperand(1).getMBB() != ML->getHeader()) {
249     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
250     Revert = true;
251     return;
252   }
253 
254   // The WLS and LE instructions have 12-bits for the label offset. WLS
255   // requires a positive offset, while LE uses negative.
256   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
257       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
258     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
259     Revert = true;
260     return;
261   }
262 
263   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
264       (BBUtils->getOffsetOf(Start) >
265        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
266        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
267     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
268     Revert = true;
269     return;
270   }
271 
272   InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
273   if (!InsertPt) {
274     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
275     Revert = true;
276   } else
277     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
278 
279   LLVM_DEBUG(if (IsTailPredicationLegal()) {
280                dbgs() << "ARM Loops: Will use tail predication to convert:\n";
281                for (auto *MI : VPTUsers)
282                  dbgs() << " - " << *MI;
283              });
284 }
285 
286 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
287   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
288   if (!ST.hasLOB())
289     return false;
290 
291   MF = &mf;
292   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
293 
294   auto &MLI = getAnalysis<MachineLoopInfo>();
295   RDA = &getAnalysis<ReachingDefAnalysis>();
296   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
297   MRI = &MF->getRegInfo();
298   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
299   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
300   BBUtils->computeAllBlockSizes();
301   BBUtils->adjustBBOffsetsAfter(&MF->front());
302 
303   bool Changed = false;
304   for (auto ML : MLI) {
305     if (!ML->getParentLoop())
306       Changed |= ProcessLoop(ML);
307   }
308   Changed |= RevertNonLoops();
309   return Changed;
310 }
311 
312 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
313 
314   bool Changed = false;
315 
316   // Process inner loops first.
317   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
318     Changed |= ProcessLoop(*I);
319 
320   LLVM_DEBUG(dbgs() << "ARM Loops: Processing " << *ML);
321 
322   // Search the given block for a loop start instruction. If one isn't found,
323   // and there's only one predecessor block, search that one too.
324   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
325     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
326     for (auto &MI : *MBB) {
327       if (IsLoopStart(MI))
328         return &MI;
329     }
330     if (MBB->pred_size() == 1)
331       return SearchForStart(*MBB->pred_begin());
332     return nullptr;
333   };
334 
335   LowOverheadLoop LoLoop(ML);
336   // Search the preheader for the start intrinsic, or look through the
337   // predecessors of the header to find exactly one set.iterations intrinsic.
338   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
339   // with potentially multiple set.loop.iterations, so we need to enable this.
340   if (auto *Preheader = ML->getLoopPreheader())
341     LoLoop.Start = SearchForStart(Preheader);
342   else {
343     LLVM_DEBUG(dbgs() << "ARM Loops: Failed to find loop preheader!\n"
344                << " - Performing manual predecessor search.\n");
345     MachineBasicBlock *Pred = nullptr;
346     for (auto *MBB : ML->getHeader()->predecessors()) {
347       if (!ML->contains(MBB)) {
348         if (Pred) {
349           LLVM_DEBUG(dbgs() << " - Found multiple out-of-loop preds.\n");
350           LoLoop.Start = nullptr;
351           break;
352         }
353         Pred = MBB;
354         LoLoop.Start = SearchForStart(MBB);
355       }
356     }
357   }
358 
359   // Find the low-overhead loop components and decide whether or not to fall
360   // back to a normal loop. Also look for a vctp instructions and decide
361   // whether we can convert that predicate using tail predication.
362   for (auto *MBB : reverse(ML->getBlocks())) {
363     for (auto &MI : *MBB) {
364       if (MI.getOpcode() == ARM::t2LoopDec)
365         LoLoop.Dec = &MI;
366       else if (MI.getOpcode() == ARM::t2LoopEnd)
367         LoLoop.End = &MI;
368       else if (IsLoopStart(MI))
369         LoLoop.Start = &MI;
370       else if (IsVCTP(&MI))
371         LoLoop.addVCTP(&MI);
372       else if (MI.getDesc().isCall()) {
373         // TODO: Though the call will require LE to execute again, does this
374         // mean we should revert? Always executing LE hopefully should be
375         // faster than performing a sub,cmp,br or even subs,br.
376         LoLoop.Revert = true;
377         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
378       } else {
379         // Once we've found a vctp, record the users of vpr and check there's
380         // no more vpr defs.
381         if (LoLoop.FoundOneVCTP)
382           LoLoop.ScanForVPR(&MI);
383         // Check we know how to tail predicate any mve instructions.
384         LoLoop.CheckTPValidity(&MI);
385       }
386 
387       // We need to ensure that LR is not used or defined inbetween LoopDec and
388       // LoopEnd.
389       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
390         continue;
391 
392       // If we find that LR has been written or read between LoopDec and
393       // LoopEnd, expect that the decremented value is being used else where.
394       // Because this value isn't actually going to be produced until the
395       // latch, by LE, we would need to generate a real sub. The value is also
396       // likely to be copied/reloaded for use of LoopEnd - in which in case
397       // we'd need to perform an add because it gets subtracted again by LE!
398       // The other option is to then generate the other form of LE which doesn't
399       // perform the sub.
400       for (auto &MO : MI.operands()) {
401         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
402             MO.getReg() == ARM::LR) {
403           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
404           LoLoop.Revert = true;
405           break;
406         }
407       }
408     }
409   }
410 
411   LLVM_DEBUG(LoLoop.dump());
412   if (!LoLoop.FoundAllComponents())
413     return false;
414 
415   LoLoop.CheckLegality(BBUtils.get(), RDA);
416   Expand(LoLoop);
417   return true;
418 }
419 
420 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
421 // beq that branches to the exit branch.
422 // TODO: We could also try to generate a cbz if the value in LR is also in
423 // another low register.
424 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
425   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
426   MachineBasicBlock *MBB = MI->getParent();
427   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
428                                     TII->get(ARM::t2CMPri));
429   MIB.add(MI->getOperand(0));
430   MIB.addImm(0);
431   MIB.addImm(ARMCC::AL);
432   MIB.addReg(ARM::NoRegister);
433 
434   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
435   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
436     ARM::tBcc : ARM::t2Bcc;
437 
438   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
439   MIB.add(MI->getOperand(1));   // branch target
440   MIB.addImm(ARMCC::EQ);        // condition code
441   MIB.addReg(ARM::CPSR);
442   MI->eraseFromParent();
443 }
444 
445 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
446                                         bool SetFlags) const {
447   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
448   MachineBasicBlock *MBB = MI->getParent();
449 
450   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
451   if (SetFlags &&
452       (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
453        !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
454       SetFlags = false;
455 
456   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
457                                     TII->get(ARM::t2SUBri));
458   MIB.addDef(ARM::LR);
459   MIB.add(MI->getOperand(1));
460   MIB.add(MI->getOperand(2));
461   MIB.addImm(ARMCC::AL);
462   MIB.addReg(0);
463 
464   if (SetFlags) {
465     MIB.addReg(ARM::CPSR);
466     MIB->getOperand(5).setIsDef(true);
467   } else
468     MIB.addReg(0);
469 
470   MI->eraseFromParent();
471   return SetFlags;
472 }
473 
474 // Generate a subs, or sub and cmp, and a branch instead of an LE.
475 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
476   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
477 
478   MachineBasicBlock *MBB = MI->getParent();
479   // Create cmp
480   if (!SkipCmp) {
481     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
482                                       TII->get(ARM::t2CMPri));
483     MIB.addReg(ARM::LR);
484     MIB.addImm(0);
485     MIB.addImm(ARMCC::AL);
486     MIB.addReg(ARM::NoRegister);
487   }
488 
489   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
490   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
491     ARM::tBcc : ARM::t2Bcc;
492 
493   // Create bne
494   MachineInstrBuilder MIB =
495     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
496   MIB.add(MI->getOperand(1));   // branch target
497   MIB.addImm(ARMCC::NE);        // condition code
498   MIB.addReg(ARM::CPSR);
499   MI->eraseFromParent();
500 }
501 
502 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
503   MachineInstr *InsertPt = LoLoop.InsertPt;
504   MachineInstr *Start = LoLoop.Start;
505   MachineBasicBlock *MBB = InsertPt->getParent();
506   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
507   unsigned Opc = 0;
508 
509   if (!LoLoop.IsTailPredicationLegal())
510     Opc = IsDo ? ARM::t2DLS : ARM::t2WLS;
511   else {
512     switch (LoLoop.VCTP->getOpcode()) {
513     case ARM::MVE_VCTP8:
514       Opc = IsDo ? ARM::MVE_DLSTP_8 : ARM::MVE_WLSTP_8;
515       break;
516     case ARM::MVE_VCTP16:
517       Opc = IsDo ? ARM::MVE_DLSTP_16 : ARM::MVE_WLSTP_16;
518       break;
519     case ARM::MVE_VCTP32:
520       Opc = IsDo ? ARM::MVE_DLSTP_32 : ARM::MVE_WLSTP_32;
521       break;
522     case ARM::MVE_VCTP64:
523       Opc = IsDo ? ARM::MVE_DLSTP_64 : ARM::MVE_WLSTP_64;
524       break;
525     }
526   }
527 
528   MachineInstrBuilder MIB =
529     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
530 
531   MIB.addDef(ARM::LR);
532   MIB.add(Start->getOperand(0));
533   if (!IsDo)
534     MIB.add(Start->getOperand(1));
535 
536   if (InsertPt != Start)
537     InsertPt->eraseFromParent();
538   Start->eraseFromParent();
539   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
540   return &*MIB;
541 }
542 
543 void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) {
544   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
545   LoLoop.VCTP->eraseFromParent();
546 
547   for (auto *MI : LoLoop.VPTUsers) {
548     if (MI->getOpcode() == ARM::MVE_VPST) {
549       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI);
550       MI->eraseFromParent();
551     } else {
552       unsigned OpNum = MI->getNumOperands() - 1;
553       assert((MI->getOperand(OpNum).isReg() &&
554               MI->getOperand(OpNum).getReg() == ARM::VPR) &&
555              "Expected VPR");
556       assert((MI->getOperand(OpNum-1).isImm() &&
557               MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) &&
558              "Expected Then predicate");
559       MI->getOperand(OpNum-1).setImm(ARMVCC::None);
560       MI->getOperand(OpNum).setReg(0);
561       LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI);
562     }
563   }
564 }
565 
566 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
567 
568   // Combine the LoopDec and LoopEnd instructions into LE(TP).
569   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
570     MachineInstr *End = LoLoop.End;
571     MachineBasicBlock *MBB = End->getParent();
572     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
573       ARM::MVE_LETP : ARM::t2LEUpdate;
574     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
575                                       TII->get(Opc));
576     MIB.addDef(ARM::LR);
577     MIB.add(End->getOperand(0));
578     MIB.add(End->getOperand(1));
579     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
580 
581     LoLoop.End->eraseFromParent();
582     LoLoop.Dec->eraseFromParent();
583     return &*MIB;
584   };
585 
586   // TODO: We should be able to automatically remove these branches before we
587   // get here - probably by teaching analyzeBranch about the pseudo
588   // instructions.
589   // If there is an unconditional branch, after I, that just branches to the
590   // next block, remove it.
591   auto RemoveDeadBranch = [](MachineInstr *I) {
592     MachineBasicBlock *BB = I->getParent();
593     MachineInstr *Terminator = &BB->instr_back();
594     if (Terminator->isUnconditionalBranch() && I != Terminator) {
595       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
596       if (BB->isLayoutSuccessor(Succ)) {
597         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
598         Terminator->eraseFromParent();
599       }
600     }
601   };
602 
603   if (LoLoop.Revert) {
604     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
605       RevertWhile(LoLoop.Start);
606     else
607       LoLoop.Start->eraseFromParent();
608     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
609     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
610   } else {
611     LoLoop.Start = ExpandLoopStart(LoLoop);
612     RemoveDeadBranch(LoLoop.Start);
613     LoLoop.End = ExpandLoopEnd(LoLoop);
614     RemoveDeadBranch(LoLoop.End);
615     if (LoLoop.IsTailPredicationLegal())
616       RemoveVPTBlocks(LoLoop);
617   }
618 }
619 
620 bool ARMLowOverheadLoops::RevertNonLoops() {
621   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
622   bool Changed = false;
623 
624   for (auto &MBB : *MF) {
625     SmallVector<MachineInstr*, 4> Starts;
626     SmallVector<MachineInstr*, 4> Decs;
627     SmallVector<MachineInstr*, 4> Ends;
628 
629     for (auto &I : MBB) {
630       if (IsLoopStart(I))
631         Starts.push_back(&I);
632       else if (I.getOpcode() == ARM::t2LoopDec)
633         Decs.push_back(&I);
634       else if (I.getOpcode() == ARM::t2LoopEnd)
635         Ends.push_back(&I);
636     }
637 
638     if (Starts.empty() && Decs.empty() && Ends.empty())
639       continue;
640 
641     Changed = true;
642 
643     for (auto *Start : Starts) {
644       if (Start->getOpcode() == ARM::t2WhileLoopStart)
645         RevertWhile(Start);
646       else
647         Start->eraseFromParent();
648     }
649     for (auto *Dec : Decs)
650       RevertLoopDec(Dec);
651 
652     for (auto *End : Ends)
653       RevertLoopEnd(End);
654   }
655   return Changed;
656 }
657 
/// Create a new instance of the ARM low-overhead loops pass. The instance is
/// heap-allocated and returned as a raw pointer.
FunctionPass *llvm::createARMLowOverheadLoopsPass() {
  return new ARMLowOverheadLoops();
}
661