xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision 8978c12b39f90194bb35860729ddca5e819f3b92)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
14 ///   preheaders only predecessor.
15 /// - t2LoopDec - placed within in the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 //===----------------------------------------------------------------------===//
19 
20 #include "ARM.h"
21 #include "ARMBaseInstrInfo.h"
22 #include "ARMBaseRegisterInfo.h"
23 #include "ARMBasicBlockInfo.h"
24 #include "ARMSubtarget.h"
25 #include "llvm/CodeGen/MachineFunctionPass.h"
26 #include "llvm/CodeGen/MachineLoopInfo.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/MC/MCInstrDesc.h"
29 
30 using namespace llvm;
31 
32 #define DEBUG_TYPE "arm-low-overhead-loops"
33 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
34 
35 namespace {
36 
37   struct LowOverheadLoop {
38 
39     MachineLoop *ML = nullptr;
40     MachineFunction *MF = nullptr;
41     MachineInstr *InsertPt = nullptr;
42     MachineInstr *Start = nullptr;
43     MachineInstr *Dec = nullptr;
44     MachineInstr *End = nullptr;
45     MachineInstr *VCTP = nullptr;
46     SmallVector<MachineInstr*, 4> VPTUsers;
47     bool Revert = false;
48     bool FoundOneVCTP = false;
49     bool CannotTailPredicate = false;
50 
51     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
52       MF = ML->getHeader()->getParent();
53     }
54 
55     // For now, only support one vctp instruction. If we find multiple then
56     // we shouldn't perform tail predication.
57     void addVCTP(MachineInstr *MI) {
58       if (!VCTP) {
59         VCTP = MI;
60         FoundOneVCTP = true;
61       } else
62         FoundOneVCTP = false;
63     }
64 
65     // Check that nothing else is writing to VPR and record any insts
66     // reading the VPR.
67     void ScanForVPR(MachineInstr *MI) {
68       for (auto &MO : MI->operands()) {
69         if (!MO.isReg() || MO.getReg() != ARM::VPR)
70           continue;
71         if (MO.isUse())
72           VPTUsers.push_back(MI);
73         if (MO.isDef()) {
74           CannotTailPredicate = true;
75           break;
76         }
77       }
78     }
79 
80     // If this is an MVE instruction, check that we know how to use tail
81     // predication with it.
82     void CheckTPValidity(MachineInstr *MI) {
83       if (CannotTailPredicate)
84         return;
85 
86       const MCInstrDesc &MCID = MI->getDesc();
87       uint64_t Flags = MCID.TSFlags;
88       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
89         return;
90 
91       if ((Flags & ARMII::ValidForTailPredication) == 0) {
92         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
93         CannotTailPredicate = true;
94       }
95     }
96 
97     bool IsTailPredicationLegal() const {
98       // For now, let's keep things really simple and only support a single
99       // block for tail predication.
100       return !Revert && FoundAllComponents() && FoundOneVCTP &&
101              !CannotTailPredicate && ML->getNumBlocks() == 1;
102     }
103 
104     // Is it safe to define LR with DLS/WLS?
105     // LR can be defined if it is the operand to start, because it's the same
106     // value, or if it's going to be equivalent to the operand to Start.
107     MachineInstr *IsSafeToDefineLR();
108 
109     // Check the branch targets are within range and we satisfy our restructi
110     void CheckLegality(ARMBasicBlockUtils *BBUtils);
111 
112     bool FoundAllComponents() const {
113       return Start && Dec && End;
114     }
115 
116     void dump() const {
117       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
118       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
119       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
120       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
121       if (!FoundAllComponents())
122         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
123       else if (!(Start && Dec && End))
124         dbgs() << "ARM Loops: Failed to find all loop components.\n";
125     }
126   };
127 
128   class ARMLowOverheadLoops : public MachineFunctionPass {
129     MachineFunction           *MF = nullptr;
130     const ARMBaseInstrInfo    *TII = nullptr;
131     MachineRegisterInfo       *MRI = nullptr;
132     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
133 
134   public:
135     static char ID;
136 
137     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
138 
139     void getAnalysisUsage(AnalysisUsage &AU) const override {
140       AU.setPreservesCFG();
141       AU.addRequired<MachineLoopInfo>();
142       MachineFunctionPass::getAnalysisUsage(AU);
143     }
144 
145     bool runOnMachineFunction(MachineFunction &MF) override;
146 
147     MachineFunctionProperties getRequiredProperties() const override {
148       return MachineFunctionProperties().set(
149           MachineFunctionProperties::Property::NoVRegs);
150     }
151 
152     StringRef getPassName() const override {
153       return ARM_LOW_OVERHEAD_LOOPS_NAME;
154     }
155 
156   private:
157     bool ProcessLoop(MachineLoop *ML);
158 
159     bool RevertNonLoops();
160 
161     void RevertWhile(MachineInstr *MI) const;
162 
163     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
164 
165     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
166 
167     void RemoveVPTBlocks(LowOverheadLoop &LoLoop);
168 
169     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
170 
171     void Expand(LowOverheadLoop &LoLoop);
172 
173   };
174 }
175 
176 char ARMLowOverheadLoops::ID = 0;
177 
178 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
179                 false, false)
180 
181 static bool IsLoopStart(MachineInstr &MI) {
182   return MI.getOpcode() == ARM::t2DoLoopStart ||
183          MI.getOpcode() == ARM::t2WhileLoopStart;
184 }
185 
186 template<typename T>
187 static MachineInstr* SearchForDef(MachineInstr *Begin, T End, unsigned Reg) {
188   for(auto &MI : make_range(T(Begin), End)) {
189     for (auto &MO : MI.operands()) {
190       if (!MO.isReg() || !MO.isDef() || MO.getReg() != Reg)
191         continue;
192       return &MI;
193     }
194   }
195   return nullptr;
196 }
197 
198 static MachineInstr* SearchForUse(MachineInstr *Begin,
199                                   MachineBasicBlock::iterator End,
200                                   unsigned Reg) {
201   for(auto &MI : make_range(MachineBasicBlock::iterator(Begin), End)) {
202     for (auto &MO : MI.operands()) {
203       if (!MO.isReg() || !MO.isUse() || MO.getReg() != Reg)
204         continue;
205       return &MI;
206     }
207   }
208   return nullptr;
209 }
210 
211 static bool IsVCTP(MachineInstr *MI) {
212   switch (MI->getOpcode()) {
213   default:
214     break;
215   case ARM::MVE_VCTP8:
216   case ARM::MVE_VCTP16:
217   case ARM::MVE_VCTP32:
218   case ARM::MVE_VCTP64:
219     return true;
220   }
221   return false;
222 }
223 
224 MachineInstr *LowOverheadLoop::IsSafeToDefineLR() {
225 
226   auto IsMoveLR = [](MachineInstr *MI, unsigned Reg) {
227     return MI->getOpcode() == ARM::tMOVr &&
228            MI->getOperand(0).getReg() == ARM::LR &&
229            MI->getOperand(1).getReg() == Reg &&
230            MI->getOperand(2).getImm() == ARMCC::AL;
231    };
232 
233   MachineBasicBlock *MBB = Start->getParent();
234   unsigned CountReg = Start->getOperand(0).getReg();
235   // Walk forward and backward in the block to find the closest instructions
236   // that define LR. Then also filter them out if they're not a mov lr.
237   MachineInstr *PredLRDef = SearchForDef(Start, MBB->rend(), ARM::LR);
238   if (PredLRDef && !IsMoveLR(PredLRDef, CountReg))
239     PredLRDef = nullptr;
240 
241   MachineInstr *SuccLRDef = SearchForDef(Start, MBB->end(), ARM::LR);
242   if (SuccLRDef && !IsMoveLR(SuccLRDef, CountReg))
243     SuccLRDef = nullptr;
244 
245   // We've either found one, two or none mov lr instructions... Now figure out
246   // if they are performing the equilvant mov that the Start instruction will.
247   // Do this by scanning forward and backward to see if there's a def of the
248   // register holding the count value. If we find a suitable def, return it as
249   // the insert point. Later, if InsertPt != Start, then we can remove the
250   // redundant instruction.
251   if (SuccLRDef) {
252     MachineBasicBlock::iterator End(SuccLRDef);
253     if (!SearchForDef(Start, End, CountReg)) {
254       return SuccLRDef;
255     } else
256       SuccLRDef = nullptr;
257   }
258   if (PredLRDef) {
259     MachineBasicBlock::reverse_iterator End(PredLRDef);
260     if (!SearchForDef(Start, End, CountReg)) {
261       return PredLRDef;
262     } else
263       PredLRDef = nullptr;
264   }
265 
266   // We can define LR because LR already contains the same value.
267   if (Start->getOperand(0).getReg() == ARM::LR)
268     return Start;
269 
270   // We've found no suitable LR def and Start doesn't use LR directly. Can we
271   // just define LR anyway?
272   const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
273   LivePhysRegs LiveRegs(*TRI);
274   LiveRegs.addLiveOuts(*MBB);
275 
276   // Not if we've haven't found a suitable mov and LR is live out.
277   if (LiveRegs.contains(ARM::LR))
278     return nullptr;
279 
280   // If LR is not live out, we can insert the instruction if nothing else
281   // uses LR after it.
282   if (!SearchForUse(Start, MBB->end(), ARM::LR))
283     return Start;
284 
285   LLVM_DEBUG(dbgs() << "ARM Loops: Failed to find suitable insertion point for"
286              << " LR\n");
287   return nullptr;
288 }
289 
290 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
291   if (Revert)
292     return;
293 
294   if (!End->getOperand(1).isMBB())
295     report_fatal_error("Expected LoopEnd to target basic block");
296 
297   // TODO Maybe there's cases where the target doesn't have to be the header,
298   // but for now be safe and revert.
299   if (End->getOperand(1).getMBB() != ML->getHeader()) {
300     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
301     Revert = true;
302     return;
303   }
304 
305   // The WLS and LE instructions have 12-bits for the label offset. WLS
306   // requires a positive offset, while LE uses negative.
307   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
308       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
309     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
310     Revert = true;
311     return;
312   }
313 
314   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
315       (BBUtils->getOffsetOf(Start) >
316        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
317        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
318     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
319     Revert = true;
320     return;
321   }
322 
323   InsertPt = Revert ? nullptr : IsSafeToDefineLR();
324   if (!InsertPt) {
325     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
326     Revert = true;
327   } else
328     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
329 
330   LLVM_DEBUG(if (IsTailPredicationLegal()) {
331                dbgs() << "ARM Loops: Will use tail predication to convert:\n";
332                for (auto *MI : VPTUsers)
333                  dbgs() << " - " << *MI;
334              });
335 }
336 
337 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
338   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
339   if (!ST.hasLOB())
340     return false;
341 
342   MF = &mf;
343   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
344 
345   auto &MLI = getAnalysis<MachineLoopInfo>();
346   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
347   MRI = &MF->getRegInfo();
348   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
349   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
350   BBUtils->computeAllBlockSizes();
351   BBUtils->adjustBBOffsetsAfter(&MF->front());
352 
353   bool Changed = false;
354   for (auto ML : MLI) {
355     if (!ML->getParentLoop())
356       Changed |= ProcessLoop(ML);
357   }
358   Changed |= RevertNonLoops();
359   return Changed;
360 }
361 
362 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
363 
364   bool Changed = false;
365 
366   // Process inner loops first.
367   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
368     Changed |= ProcessLoop(*I);
369 
370   LLVM_DEBUG(dbgs() << "ARM Loops: Processing " << *ML);
371 
372   // Search the given block for a loop start instruction. If one isn't found,
373   // and there's only one predecessor block, search that one too.
374   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
375     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
376     for (auto &MI : *MBB) {
377       if (IsLoopStart(MI))
378         return &MI;
379     }
380     if (MBB->pred_size() == 1)
381       return SearchForStart(*MBB->pred_begin());
382     return nullptr;
383   };
384 
385   LowOverheadLoop LoLoop(ML);
386   // Search the preheader for the start intrinsic, or look through the
387   // predecessors of the header to find exactly one set.iterations intrinsic.
388   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
389   // with potentially multiple set.loop.iterations, so we need to enable this.
390   if (auto *Preheader = ML->getLoopPreheader())
391     LoLoop.Start = SearchForStart(Preheader);
392   else {
393     LLVM_DEBUG(dbgs() << "ARM Loops: Failed to find loop preheader!\n"
394                << " - Performing manual predecessor search.\n");
395     MachineBasicBlock *Pred = nullptr;
396     for (auto *MBB : ML->getHeader()->predecessors()) {
397       if (!ML->contains(MBB)) {
398         if (Pred) {
399           LLVM_DEBUG(dbgs() << " - Found multiple out-of-loop preds.\n");
400           LoLoop.Start = nullptr;
401           break;
402         }
403         Pred = MBB;
404         LoLoop.Start = SearchForStart(MBB);
405       }
406     }
407   }
408 
409   // Find the low-overhead loop components and decide whether or not to fall
410   // back to a normal loop. Also look for a vctp instructions and decide
411   // whether we can convert that predicate using tail predication.
412   for (auto *MBB : reverse(ML->getBlocks())) {
413     for (auto &MI : *MBB) {
414       if (MI.getOpcode() == ARM::t2LoopDec)
415         LoLoop.Dec = &MI;
416       else if (MI.getOpcode() == ARM::t2LoopEnd)
417         LoLoop.End = &MI;
418       else if (IsLoopStart(MI))
419         LoLoop.Start = &MI;
420       else if (IsVCTP(&MI))
421         LoLoop.addVCTP(&MI);
422       else if (MI.getDesc().isCall()) {
423         // TODO: Though the call will require LE to execute again, does this
424         // mean we should revert? Always executing LE hopefully should be
425         // faster than performing a sub,cmp,br or even subs,br.
426         LoLoop.Revert = true;
427         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
428       } else {
429         // Once we've found a vctp, record the users of vpr and check there's
430         // no more vpr defs.
431         if (LoLoop.FoundOneVCTP)
432           LoLoop.ScanForVPR(&MI);
433         // Check we know how to tail predicate any mve instructions.
434         LoLoop.CheckTPValidity(&MI);
435       }
436 
437       // We need to ensure that LR is not used or defined inbetween LoopDec and
438       // LoopEnd.
439       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
440         continue;
441 
442       // If we find that LR has been written or read between LoopDec and
443       // LoopEnd, expect that the decremented value is being used else where.
444       // Because this value isn't actually going to be produced until the
445       // latch, by LE, we would need to generate a real sub. The value is also
446       // likely to be copied/reloaded for use of LoopEnd - in which in case
447       // we'd need to perform an add because it gets subtracted again by LE!
448       // The other option is to then generate the other form of LE which doesn't
449       // perform the sub.
450       for (auto &MO : MI.operands()) {
451         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
452             MO.getReg() == ARM::LR) {
453           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
454           LoLoop.Revert = true;
455           break;
456         }
457       }
458     }
459   }
460 
461   LLVM_DEBUG(LoLoop.dump());
462   if (!LoLoop.FoundAllComponents())
463     return false;
464 
465   LoLoop.CheckLegality(BBUtils.get());
466   Expand(LoLoop);
467   return true;
468 }
469 
470 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
471 // beq that branches to the exit branch.
472 // TODO: We could also try to generate a cbz if the value in LR is also in
473 // another low register.
474 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
475   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
476   MachineBasicBlock *MBB = MI->getParent();
477   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
478                                     TII->get(ARM::t2CMPri));
479   MIB.add(MI->getOperand(0));
480   MIB.addImm(0);
481   MIB.addImm(ARMCC::AL);
482   MIB.addReg(ARM::NoRegister);
483 
484   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
485   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
486     ARM::tBcc : ARM::t2Bcc;
487 
488   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
489   MIB.add(MI->getOperand(1));   // branch target
490   MIB.addImm(ARMCC::EQ);        // condition code
491   MIB.addReg(ARM::CPSR);
492   MI->eraseFromParent();
493 }
494 
495 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
496                                         bool AllowFlags) const {
497   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
498   MachineBasicBlock *MBB = MI->getParent();
499 
500   // If nothing uses or defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
501   bool SetFlags = false;
502   if (AllowFlags) {
503     if (auto *Def = SearchForDef(MI, MBB->end(), ARM::CPSR)) {
504       if (!SearchForUse(MI, MBB->end(), ARM::CPSR) &&
505           Def->getOpcode() == ARM::t2LoopEnd)
506         SetFlags = true;
507     }
508   }
509 
510   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
511                                     TII->get(ARM::t2SUBri));
512   MIB.addDef(ARM::LR);
513   MIB.add(MI->getOperand(1));
514   MIB.add(MI->getOperand(2));
515   MIB.addImm(ARMCC::AL);
516   MIB.addReg(0);
517 
518   if (SetFlags) {
519     MIB.addReg(ARM::CPSR);
520     MIB->getOperand(5).setIsDef(true);
521   } else
522     MIB.addReg(0);
523 
524   MI->eraseFromParent();
525   return SetFlags;
526 }
527 
528 // Generate a subs, or sub and cmp, and a branch instead of an LE.
529 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
530   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
531 
532   MachineBasicBlock *MBB = MI->getParent();
533   // Create cmp
534   if (!SkipCmp) {
535     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
536                                       TII->get(ARM::t2CMPri));
537     MIB.addReg(ARM::LR);
538     MIB.addImm(0);
539     MIB.addImm(ARMCC::AL);
540     MIB.addReg(ARM::NoRegister);
541   }
542 
543   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
544   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
545     ARM::tBcc : ARM::t2Bcc;
546 
547   // Create bne
548   MachineInstrBuilder MIB =
549     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
550   MIB.add(MI->getOperand(1));   // branch target
551   MIB.addImm(ARMCC::NE);        // condition code
552   MIB.addReg(ARM::CPSR);
553   MI->eraseFromParent();
554 }
555 
556 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
557   MachineInstr *InsertPt = LoLoop.InsertPt;
558   MachineInstr *Start = LoLoop.Start;
559   MachineBasicBlock *MBB = InsertPt->getParent();
560   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
561   unsigned Opc = 0;
562 
563   if (!LoLoop.IsTailPredicationLegal())
564     Opc = IsDo ? ARM::t2DLS : ARM::t2WLS;
565   else {
566     switch (LoLoop.VCTP->getOpcode()) {
567     case ARM::MVE_VCTP8:
568       Opc = IsDo ? ARM::MVE_DLSTP_8 : ARM::MVE_WLSTP_8;
569       break;
570     case ARM::MVE_VCTP16:
571       Opc = IsDo ? ARM::MVE_DLSTP_16 : ARM::MVE_WLSTP_16;
572       break;
573     case ARM::MVE_VCTP32:
574       Opc = IsDo ? ARM::MVE_DLSTP_32 : ARM::MVE_WLSTP_32;
575       break;
576     case ARM::MVE_VCTP64:
577       Opc = IsDo ? ARM::MVE_DLSTP_64 : ARM::MVE_WLSTP_64;
578       break;
579     }
580   }
581 
582   MachineInstrBuilder MIB =
583     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
584 
585   MIB.addDef(ARM::LR);
586   MIB.add(Start->getOperand(0));
587   if (!IsDo)
588     MIB.add(Start->getOperand(1));
589 
590   if (InsertPt != Start)
591     InsertPt->eraseFromParent();
592   Start->eraseFromParent();
593   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
594   return &*MIB;
595 }
596 
597 void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) {
598   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
599   LoLoop.VCTP->eraseFromParent();
600 
601   for (auto *MI : LoLoop.VPTUsers) {
602     if (MI->getOpcode() == ARM::MVE_VPST) {
603       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI);
604       MI->eraseFromParent();
605     } else {
606       unsigned OpNum = MI->getNumOperands() - 1;
607       assert((MI->getOperand(OpNum).isReg() &&
608               MI->getOperand(OpNum).getReg() == ARM::VPR) &&
609              "Expected VPR");
610       assert((MI->getOperand(OpNum-1).isImm() &&
611               MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) &&
612              "Expected Then predicate");
613       MI->getOperand(OpNum-1).setImm(ARMVCC::None);
614       MI->getOperand(OpNum).setReg(0);
615       LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI);
616     }
617   }
618 }
619 
620 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
621 
622   // Combine the LoopDec and LoopEnd instructions into LE(TP).
623   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
624     MachineInstr *End = LoLoop.End;
625     MachineBasicBlock *MBB = End->getParent();
626     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
627       ARM::MVE_LETP : ARM::t2LEUpdate;
628     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
629                                       TII->get(Opc));
630     MIB.addDef(ARM::LR);
631     MIB.add(End->getOperand(0));
632     MIB.add(End->getOperand(1));
633     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
634 
635     LoLoop.End->eraseFromParent();
636     LoLoop.Dec->eraseFromParent();
637     return &*MIB;
638   };
639 
640   // TODO: We should be able to automatically remove these branches before we
641   // get here - probably by teaching analyzeBranch about the pseudo
642   // instructions.
643   // If there is an unconditional branch, after I, that just branches to the
644   // next block, remove it.
645   auto RemoveDeadBranch = [](MachineInstr *I) {
646     MachineBasicBlock *BB = I->getParent();
647     MachineInstr *Terminator = &BB->instr_back();
648     if (Terminator->isUnconditionalBranch() && I != Terminator) {
649       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
650       if (BB->isLayoutSuccessor(Succ)) {
651         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
652         Terminator->eraseFromParent();
653       }
654     }
655   };
656 
657   if (LoLoop.Revert) {
658     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
659       RevertWhile(LoLoop.Start);
660     else
661       LoLoop.Start->eraseFromParent();
662     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
663     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
664   } else {
665     LoLoop.Start = ExpandLoopStart(LoLoop);
666     RemoveDeadBranch(LoLoop.Start);
667     LoLoop.End = ExpandLoopEnd(LoLoop);
668     RemoveDeadBranch(LoLoop.End);
669     if (LoLoop.IsTailPredicationLegal())
670       RemoveVPTBlocks(LoLoop);
671   }
672 }
673 
674 bool ARMLowOverheadLoops::RevertNonLoops() {
675   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
676   bool Changed = false;
677 
678   for (auto &MBB : *MF) {
679     SmallVector<MachineInstr*, 4> Starts;
680     SmallVector<MachineInstr*, 4> Decs;
681     SmallVector<MachineInstr*, 4> Ends;
682 
683     for (auto &I : MBB) {
684       if (IsLoopStart(I))
685         Starts.push_back(&I);
686       else if (I.getOpcode() == ARM::t2LoopDec)
687         Decs.push_back(&I);
688       else if (I.getOpcode() == ARM::t2LoopEnd)
689         Ends.push_back(&I);
690     }
691 
692     if (Starts.empty() && Decs.empty() && Ends.empty())
693       continue;
694 
695     Changed = true;
696 
697     for (auto *Start : Starts) {
698       if (Start->getOpcode() == ARM::t2WhileLoopStart)
699         RevertWhile(Start);
700       else
701         Start->eraseFromParent();
702     }
703     for (auto *Dec : Decs)
704       RevertLoopDec(Dec);
705 
706     for (auto *End : Ends)
707       RevertLoopEnd(End);
708   }
709   return Changed;
710 }
711 
712 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
713   return new ARMLowOverheadLoops();
714 }
715