xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision 28166816b05aebb3154e5f8a28b3ef447cce8471)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 //===----------------------------------------------------------------------===//
19 
20 #include "ARM.h"
21 #include "ARMBaseInstrInfo.h"
22 #include "ARMBaseRegisterInfo.h"
23 #include "ARMBasicBlockInfo.h"
24 #include "ARMSubtarget.h"
25 #include "llvm/CodeGen/MachineFunctionPass.h"
26 #include "llvm/CodeGen/MachineLoopInfo.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/CodeGen/Passes.h"
29 #include "llvm/CodeGen/ReachingDefAnalysis.h"
30 #include "llvm/MC/MCInstrDesc.h"
31 
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "arm-low-overhead-loops"
35 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
36 
37 namespace {
38 
39   struct LowOverheadLoop {
40 
41     MachineLoop *ML = nullptr;
42     MachineFunction *MF = nullptr;
43     MachineInstr *InsertPt = nullptr;
44     MachineInstr *Start = nullptr;
45     MachineInstr *Dec = nullptr;
46     MachineInstr *End = nullptr;
47     MachineInstr *VCTP = nullptr;
48     SmallVector<MachineInstr*, 4> VPTUsers;
49     bool Revert = false;
50     bool FoundOneVCTP = false;
51     bool CannotTailPredicate = false;
52 
53     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
54       MF = ML->getHeader()->getParent();
55     }
56 
57     // For now, only support one vctp instruction. If we find multiple then
58     // we shouldn't perform tail predication.
59     void addVCTP(MachineInstr *MI) {
60       if (!VCTP) {
61         VCTP = MI;
62         FoundOneVCTP = true;
63       } else
64         FoundOneVCTP = false;
65     }
66 
67     // Check that nothing else is writing to VPR and record any insts
68     // reading the VPR.
69     void ScanForVPR(MachineInstr *MI) {
70       for (auto &MO : MI->operands()) {
71         if (!MO.isReg() || MO.getReg() != ARM::VPR)
72           continue;
73         if (MO.isUse())
74           VPTUsers.push_back(MI);
75         if (MO.isDef()) {
76           CannotTailPredicate = true;
77           break;
78         }
79       }
80     }
81 
82     // If this is an MVE instruction, check that we know how to use tail
83     // predication with it.
84     void CheckTPValidity(MachineInstr *MI) {
85       if (CannotTailPredicate)
86         return;
87 
88       const MCInstrDesc &MCID = MI->getDesc();
89       uint64_t Flags = MCID.TSFlags;
90       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
91         return;
92 
93       if ((Flags & ARMII::ValidForTailPredication) == 0) {
94         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
95         CannotTailPredicate = true;
96       }
97     }
98 
99     bool IsTailPredicationLegal() const {
100       // For now, let's keep things really simple and only support a single
101       // block for tail predication.
102       return !Revert && FoundAllComponents() && FoundOneVCTP &&
103              !CannotTailPredicate && ML->getNumBlocks() == 1;
104     }
105 
106     // Is it safe to define LR with DLS/WLS?
107     // LR can be defined if it is the operand to start, because it's the same
108     // value, or if it's going to be equivalent to the operand to Start.
109     MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
110 
111     // Check the branch targets are within range and we satisfy our
112     // restrictions.
113     void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA,
114                        MachineLoopInfo *MLI);
115 
116     bool FoundAllComponents() const {
117       return Start && Dec && End;
118     }
119 
120     // Return the loop iteration count, or the number of elements if we're tail
121     // predicating.
122     MachineOperand &getCount() {
123       return IsTailPredicationLegal() ?
124         VCTP->getOperand(1) : Start->getOperand(0);
125     }
126 
127     unsigned getStartOpcode() const {
128       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
129       if (!IsTailPredicationLegal())
130         return IsDo ? ARM::t2DLS : ARM::t2WLS;
131 
132       switch (VCTP->getOpcode()) {
133       default:
134         llvm_unreachable("unhandled vctp opcode");
135         break;
136       case ARM::MVE_VCTP8:
137         return IsDo ? ARM::MVE_DLSTP_8 : ARM::MVE_WLSTP_8;
138       case ARM::MVE_VCTP16:
139         return IsDo ? ARM::MVE_DLSTP_16 : ARM::MVE_WLSTP_16;
140       case ARM::MVE_VCTP32:
141         return IsDo ? ARM::MVE_DLSTP_32 : ARM::MVE_WLSTP_32;
142       case ARM::MVE_VCTP64:
143         return IsDo ? ARM::MVE_DLSTP_64 : ARM::MVE_WLSTP_64;
144       }
145       return 0;
146     }
147 
148     void dump() const {
149       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
150       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
151       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
152       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
153       if (!FoundAllComponents())
154         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
155       else if (!(Start && Dec && End))
156         dbgs() << "ARM Loops: Failed to find all loop components.\n";
157     }
158   };
159 
160   class ARMLowOverheadLoops : public MachineFunctionPass {
161     MachineFunction           *MF = nullptr;
162     MachineLoopInfo           *MLI = nullptr;
163     ReachingDefAnalysis       *RDA = nullptr;
164     const ARMBaseInstrInfo    *TII = nullptr;
165     MachineRegisterInfo       *MRI = nullptr;
166     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
167 
168   public:
169     static char ID;
170 
171     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
172 
173     void getAnalysisUsage(AnalysisUsage &AU) const override {
174       AU.setPreservesCFG();
175       AU.addRequired<MachineLoopInfo>();
176       AU.addRequired<ReachingDefAnalysis>();
177       MachineFunctionPass::getAnalysisUsage(AU);
178     }
179 
180     bool runOnMachineFunction(MachineFunction &MF) override;
181 
182     MachineFunctionProperties getRequiredProperties() const override {
183       return MachineFunctionProperties().set(
184           MachineFunctionProperties::Property::NoVRegs).set(
185           MachineFunctionProperties::Property::TracksLiveness);
186     }
187 
188     StringRef getPassName() const override {
189       return ARM_LOW_OVERHEAD_LOOPS_NAME;
190     }
191 
192   private:
193     bool ProcessLoop(MachineLoop *ML);
194 
195     bool RevertNonLoops();
196 
197     void RevertWhile(MachineInstr *MI) const;
198 
199     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
200 
201     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
202 
203     void RemoveVPTBlocks(LowOverheadLoop &LoLoop);
204 
205     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
206 
207     void Expand(LowOverheadLoop &LoLoop);
208 
209   };
210 }
211 
212 char ARMLowOverheadLoops::ID = 0;
213 
214 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
215                 false, false)
216 
217 static bool IsLoopStart(MachineInstr &MI) {
218   return MI.getOpcode() == ARM::t2DoLoopStart ||
219          MI.getOpcode() == ARM::t2WhileLoopStart;
220 }
221 
222 static bool IsVCTP(MachineInstr *MI) {
223   switch (MI->getOpcode()) {
224   default:
225     break;
226   case ARM::MVE_VCTP8:
227   case ARM::MVE_VCTP16:
228   case ARM::MVE_VCTP32:
229   case ARM::MVE_VCTP64:
230     return true;
231   }
232   return false;
233 }
234 
235 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
236   // We can define LR because LR already contains the same value.
237   if (Start->getOperand(0).getReg() == ARM::LR)
238     return Start;
239 
240   unsigned CountReg = Start->getOperand(0).getReg();
241   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
242     return MI->getOpcode() == ARM::tMOVr &&
243            MI->getOperand(0).getReg() == ARM::LR &&
244            MI->getOperand(1).getReg() == CountReg &&
245            MI->getOperand(2).getImm() == ARMCC::AL;
246    };
247 
248   MachineBasicBlock *MBB = Start->getParent();
249 
250   // Find an insertion point:
251   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
252   //   to Count before Start, we can insert at that mov.
253   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
254   //   to Count after Start, we can insert at that mov.
255   if (auto *LRDef = RDA->getReachingMIDef(&MBB->back(), ARM::LR)) {
256     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
257       return LRDef;
258   }
259 
260   // We've found no suitable LR def and Start doesn't use LR directly. Can we
261   // just define LR anyway?
262   if (!RDA->isRegUsedAfter(Start, ARM::LR))
263     return Start;
264 
265   return nullptr;
266 }
267 
268 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
269                                     ReachingDefAnalysis *RDA,
270                                     MachineLoopInfo *MLI) {
271   if (Revert)
272     return;
273 
274   if (!End->getOperand(1).isMBB())
275     report_fatal_error("Expected LoopEnd to target basic block");
276 
277   // TODO Maybe there's cases where the target doesn't have to be the header,
278   // but for now be safe and revert.
279   if (End->getOperand(1).getMBB() != ML->getHeader()) {
280     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
281     Revert = true;
282     return;
283   }
284 
285   // The WLS and LE instructions have 12-bits for the label offset. WLS
286   // requires a positive offset, while LE uses negative.
287   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
288       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
289     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
290     Revert = true;
291     return;
292   }
293 
294   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
295       (BBUtils->getOffsetOf(Start) >
296        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
297        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
298     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
299     Revert = true;
300     return;
301   }
302 
303   InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
304   if (!InsertPt) {
305     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
306     Revert = true;
307     return;
308   } else
309     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
310 
311   // For tail predication, we need to provide the number of elements, instead
312   // of the iteration count, to the loop start instruction. The number of
313   // elements is provided to the vctp instruction, so we need to check that
314   // we can use this register at InsertPt.
315   if (!IsTailPredicationLegal())
316     return;
317 
318   Register NumElements = VCTP->getOperand(1).getReg();
319 
320   // If the register is defined within loop, then we can't perform TP.
321   // TODO: Check whether this is just a mov of a register that would be
322   // available.
323   if (RDA->getReachingDef(VCTP, NumElements) >= 0) {
324     CannotTailPredicate = true;
325     return;
326   }
327 
328   // We can't perform TP if the register does not hold the same value at
329   // InsertPt as the liveout value.
330   MachineBasicBlock *InsertBB = InsertPt->getParent();
331   if  (!RDA->hasSameReachingDef(InsertPt, &InsertBB->back(),
332                                 NumElements)) {
333     CannotTailPredicate = true;
334     return;
335   }
336 
337   // Especially in the case of while loops, InsertBB may not be the
338   // preheader, so we need to check that the register isn't redefined
339   // before entering the loop.
340   auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB,
341                                       Register NumElements) {
342     // NumElements is redefined in this block.
343     if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0)
344       return true;
345 
346     // Don't continue searching up through multiple predecessors.
347     if (MBB->pred_size() > 1)
348       return true;
349 
350     return false;
351   };
352 
353   // First, find the block that looks like the preheader.
354   MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true);
355   if (!MBB) {
356     CannotTailPredicate = true;
357     return;
358   }
359 
360   // Then search backwards for a def, until we get to InsertBB.
361   while (MBB != InsertBB) {
362     CannotTailPredicate = CannotProvideElements(MBB, NumElements);
363     if (CannotTailPredicate)
364       return;
365     MBB = *MBB->pred_begin();
366   }
367 
368   LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication to convert:\n";
369                for (auto *MI : VPTUsers)
370                  dbgs() << " - " << *MI;);
371 }
372 
373 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
374   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
375   if (!ST.hasLOB())
376     return false;
377 
378   MF = &mf;
379   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
380 
381   MLI = &getAnalysis<MachineLoopInfo>();
382   RDA = &getAnalysis<ReachingDefAnalysis>();
383   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
384   MRI = &MF->getRegInfo();
385   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
386   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
387   BBUtils->computeAllBlockSizes();
388   BBUtils->adjustBBOffsetsAfter(&MF->front());
389 
390   bool Changed = false;
391   for (auto ML : *MLI) {
392     if (!ML->getParentLoop())
393       Changed |= ProcessLoop(ML);
394   }
395   Changed |= RevertNonLoops();
396   return Changed;
397 }
398 
399 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
400 
401   bool Changed = false;
402 
403   // Process inner loops first.
404   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
405     Changed |= ProcessLoop(*I);
406 
407   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
408              if (auto *Preheader = ML->getLoopPreheader())
409                dbgs() << " - " << Preheader->getName() << "\n";
410              else if (auto *Preheader = MLI->findLoopPreheader(ML))
411                dbgs() << " - " << Preheader->getName() << "\n";
412              for (auto *MBB : ML->getBlocks())
413                dbgs() << " - " << MBB->getName() << "\n";
414             );
415 
416   // Search the given block for a loop start instruction. If one isn't found,
417   // and there's only one predecessor block, search that one too.
418   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
419     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
420     for (auto &MI : *MBB) {
421       if (IsLoopStart(MI))
422         return &MI;
423     }
424     if (MBB->pred_size() == 1)
425       return SearchForStart(*MBB->pred_begin());
426     return nullptr;
427   };
428 
429   LowOverheadLoop LoLoop(ML);
430   // Search the preheader for the start intrinsic.
431   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
432   // with potentially multiple set.loop.iterations, so we need to enable this.
433   if (auto *Preheader = ML->getLoopPreheader())
434     LoLoop.Start = SearchForStart(Preheader);
435   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
436     LoLoop.Start = SearchForStart(Preheader);
437   else
438     return false;
439 
440   // Find the low-overhead loop components and decide whether or not to fall
441   // back to a normal loop. Also look for a vctp instructions and decide
442   // whether we can convert that predicate using tail predication.
443   for (auto *MBB : reverse(ML->getBlocks())) {
444     for (auto &MI : *MBB) {
445       if (MI.getOpcode() == ARM::t2LoopDec)
446         LoLoop.Dec = &MI;
447       else if (MI.getOpcode() == ARM::t2LoopEnd)
448         LoLoop.End = &MI;
449       else if (IsLoopStart(MI))
450         LoLoop.Start = &MI;
451       else if (IsVCTP(&MI))
452         LoLoop.addVCTP(&MI);
453       else if (MI.getDesc().isCall()) {
454         // TODO: Though the call will require LE to execute again, does this
455         // mean we should revert? Always executing LE hopefully should be
456         // faster than performing a sub,cmp,br or even subs,br.
457         LoLoop.Revert = true;
458         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
459       } else {
460         // Once we've found a vctp, record the users of vpr and check there's
461         // no more vpr defs.
462         if (LoLoop.FoundOneVCTP)
463           LoLoop.ScanForVPR(&MI);
464         // Check we know how to tail predicate any mve instructions.
465         LoLoop.CheckTPValidity(&MI);
466       }
467 
468       // We need to ensure that LR is not used or defined inbetween LoopDec and
469       // LoopEnd.
470       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
471         continue;
472 
473       // If we find that LR has been written or read between LoopDec and
474       // LoopEnd, expect that the decremented value is being used else where.
475       // Because this value isn't actually going to be produced until the
476       // latch, by LE, we would need to generate a real sub. The value is also
477       // likely to be copied/reloaded for use of LoopEnd - in which in case
478       // we'd need to perform an add because it gets subtracted again by LE!
479       // The other option is to then generate the other form of LE which doesn't
480       // perform the sub.
481       for (auto &MO : MI.operands()) {
482         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
483             MO.getReg() == ARM::LR) {
484           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
485           LoLoop.Revert = true;
486           break;
487         }
488       }
489     }
490   }
491 
492   LLVM_DEBUG(LoLoop.dump());
493   if (!LoLoop.FoundAllComponents())
494     return false;
495 
496   LoLoop.CheckLegality(BBUtils.get(), RDA, MLI);
497   Expand(LoLoop);
498   return true;
499 }
500 
501 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
502 // beq that branches to the exit branch.
503 // TODO: We could also try to generate a cbz if the value in LR is also in
504 // another low register.
505 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
506   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
507   MachineBasicBlock *MBB = MI->getParent();
508   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
509                                     TII->get(ARM::t2CMPri));
510   MIB.add(MI->getOperand(0));
511   MIB.addImm(0);
512   MIB.addImm(ARMCC::AL);
513   MIB.addReg(ARM::NoRegister);
514 
515   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
516   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
517     ARM::tBcc : ARM::t2Bcc;
518 
519   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
520   MIB.add(MI->getOperand(1));   // branch target
521   MIB.addImm(ARMCC::EQ);        // condition code
522   MIB.addReg(ARM::CPSR);
523   MI->eraseFromParent();
524 }
525 
526 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
527                                         bool SetFlags) const {
528   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
529   MachineBasicBlock *MBB = MI->getParent();
530 
531   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
532   if (SetFlags &&
533       (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
534        !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
535       SetFlags = false;
536 
537   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
538                                     TII->get(ARM::t2SUBri));
539   MIB.addDef(ARM::LR);
540   MIB.add(MI->getOperand(1));
541   MIB.add(MI->getOperand(2));
542   MIB.addImm(ARMCC::AL);
543   MIB.addReg(0);
544 
545   if (SetFlags) {
546     MIB.addReg(ARM::CPSR);
547     MIB->getOperand(5).setIsDef(true);
548   } else
549     MIB.addReg(0);
550 
551   MI->eraseFromParent();
552   return SetFlags;
553 }
554 
555 // Generate a subs, or sub and cmp, and a branch instead of an LE.
556 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
557   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
558 
559   MachineBasicBlock *MBB = MI->getParent();
560   // Create cmp
561   if (!SkipCmp) {
562     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
563                                       TII->get(ARM::t2CMPri));
564     MIB.addReg(ARM::LR);
565     MIB.addImm(0);
566     MIB.addImm(ARMCC::AL);
567     MIB.addReg(ARM::NoRegister);
568   }
569 
570   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
571   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
572     ARM::tBcc : ARM::t2Bcc;
573 
574   // Create bne
575   MachineInstrBuilder MIB =
576     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
577   MIB.add(MI->getOperand(1));   // branch target
578   MIB.addImm(ARMCC::NE);        // condition code
579   MIB.addReg(ARM::CPSR);
580   MI->eraseFromParent();
581 }
582 
583 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
584   MachineInstr *InsertPt = LoLoop.InsertPt;
585   MachineInstr *Start = LoLoop.Start;
586   MachineBasicBlock *MBB = InsertPt->getParent();
587   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
588   unsigned Opc = LoLoop.getStartOpcode();
589   MachineOperand &Count = LoLoop.getCount();
590 
591   MachineInstrBuilder MIB =
592     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
593 
594   MIB.addDef(ARM::LR);
595   MIB.add(Count);
596   if (!IsDo)
597     MIB.add(Start->getOperand(1));
598 
599   // When using tail-predication, try to delete the dead code that was used to
600   // calculate the number of loop iterations.
601   if (LoLoop.IsTailPredicationLegal()) {
602     SmallVector<MachineInstr*, 4> Killed;
603     SmallVector<MachineInstr*, 4> Dead;
604     if (auto *Def = RDA->getReachingMIDef(Start,
605                                           Start->getOperand(0).getReg())) {
606       Killed.push_back(Def);
607 
608       while (!Killed.empty()) {
609         MachineInstr *Def = Killed.back();
610         Killed.pop_back();
611         Dead.push_back(Def);
612         for (auto &MO : Def->operands()) {
613           if (!MO.isReg() || !MO.isKill())
614             continue;
615 
616           MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg());
617           if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1)
618             Killed.push_back(Kill);
619         }
620       }
621       for (auto *MI : Dead)
622         MI->eraseFromParent();
623     }
624   }
625 
626   // If we're inserting at a mov lr, then remove it as it's redundant.
627   if (InsertPt != Start)
628     InsertPt->eraseFromParent();
629   Start->eraseFromParent();
630   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
631   return &*MIB;
632 }
633 
634 void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) {
635   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
636   LoLoop.VCTP->eraseFromParent();
637 
638   for (auto *MI : LoLoop.VPTUsers) {
639     if (MI->getOpcode() == ARM::MVE_VPST) {
640       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI);
641       MI->eraseFromParent();
642     } else {
643       unsigned OpNum = MI->getNumOperands() - 1;
644       assert((MI->getOperand(OpNum).isReg() &&
645               MI->getOperand(OpNum).getReg() == ARM::VPR) &&
646              "Expected VPR");
647       assert((MI->getOperand(OpNum-1).isImm() &&
648               MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) &&
649              "Expected Then predicate");
650       MI->getOperand(OpNum-1).setImm(ARMVCC::None);
651       MI->getOperand(OpNum).setReg(0);
652       LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI);
653     }
654   }
655 }
656 
657 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
658 
659   // Combine the LoopDec and LoopEnd instructions into LE(TP).
660   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
661     MachineInstr *End = LoLoop.End;
662     MachineBasicBlock *MBB = End->getParent();
663     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
664       ARM::MVE_LETP : ARM::t2LEUpdate;
665     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
666                                       TII->get(Opc));
667     MIB.addDef(ARM::LR);
668     MIB.add(End->getOperand(0));
669     MIB.add(End->getOperand(1));
670     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
671 
672     LoLoop.End->eraseFromParent();
673     LoLoop.Dec->eraseFromParent();
674     return &*MIB;
675   };
676 
677   // TODO: We should be able to automatically remove these branches before we
678   // get here - probably by teaching analyzeBranch about the pseudo
679   // instructions.
680   // If there is an unconditional branch, after I, that just branches to the
681   // next block, remove it.
682   auto RemoveDeadBranch = [](MachineInstr *I) {
683     MachineBasicBlock *BB = I->getParent();
684     MachineInstr *Terminator = &BB->instr_back();
685     if (Terminator->isUnconditionalBranch() && I != Terminator) {
686       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
687       if (BB->isLayoutSuccessor(Succ)) {
688         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
689         Terminator->eraseFromParent();
690       }
691     }
692   };
693 
694   if (LoLoop.Revert) {
695     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
696       RevertWhile(LoLoop.Start);
697     else
698       LoLoop.Start->eraseFromParent();
699     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
700     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
701   } else {
702     LoLoop.Start = ExpandLoopStart(LoLoop);
703     RemoveDeadBranch(LoLoop.Start);
704     LoLoop.End = ExpandLoopEnd(LoLoop);
705     RemoveDeadBranch(LoLoop.End);
706     if (LoLoop.IsTailPredicationLegal())
707       RemoveVPTBlocks(LoLoop);
708   }
709 }
710 
711 bool ARMLowOverheadLoops::RevertNonLoops() {
712   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
713   bool Changed = false;
714 
715   for (auto &MBB : *MF) {
716     SmallVector<MachineInstr*, 4> Starts;
717     SmallVector<MachineInstr*, 4> Decs;
718     SmallVector<MachineInstr*, 4> Ends;
719 
720     for (auto &I : MBB) {
721       if (IsLoopStart(I))
722         Starts.push_back(&I);
723       else if (I.getOpcode() == ARM::t2LoopDec)
724         Decs.push_back(&I);
725       else if (I.getOpcode() == ARM::t2LoopEnd)
726         Ends.push_back(&I);
727     }
728 
729     if (Starts.empty() && Decs.empty() && Ends.empty())
730       continue;
731 
732     Changed = true;
733 
734     for (auto *Start : Starts) {
735       if (Start->getOpcode() == ARM::t2WhileLoopStart)
736         RevertWhile(Start);
737       else
738         Start->eraseFromParent();
739     }
740     for (auto *Dec : Decs)
741       RevertLoopDec(Dec);
742 
743     for (auto *End : Ends)
744       RevertLoopEnd(End);
745   }
746   return Changed;
747 }
748 
749 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
750   return new ARMLowOverheadLoops();
751 }
752