xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision acbc9aed726d4b7428691e026a214cb26ee2cf94)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 //===----------------------------------------------------------------------===//
19 
20 #include "ARM.h"
21 #include "ARMBaseInstrInfo.h"
22 #include "ARMBaseRegisterInfo.h"
23 #include "ARMBasicBlockInfo.h"
24 #include "ARMSubtarget.h"
25 #include "llvm/ADT/SetOperations.h"
26 #include "llvm/ADT/SmallSet.h"
27 #include "llvm/CodeGen/MachineFunctionPass.h"
28 #include "llvm/CodeGen/MachineLoopInfo.h"
29 #include "llvm/CodeGen/MachineLoopUtils.h"
30 #include "llvm/CodeGen/MachineRegisterInfo.h"
31 #include "llvm/CodeGen/Passes.h"
32 #include "llvm/CodeGen/ReachingDefAnalysis.h"
33 #include "llvm/MC/MCInstrDesc.h"
34 
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "arm-low-overhead-loops"
38 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
39 
40 namespace {
41 
42   struct PredicatedMI {
43     MachineInstr *MI = nullptr;
44     SetVector<MachineInstr*> Predicates;
45 
46   public:
47     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
48     MI(I) {
49       Predicates.insert(Preds.begin(), Preds.end());
50     }
51   };
52 
53   // Represent a VPT block, a list of instructions that begins with a VPST and
54   // has a maximum of four proceeding instructions. All instructions within the
55   // block are predicated upon the vpr and we allow instructions to define the
56   // vpr within in the block too.
57   class VPTBlock {
58     std::unique_ptr<PredicatedMI> VPST;
59     PredicatedMI *Divergent = nullptr;
60     SmallVector<PredicatedMI, 4> Insts;
61 
62   public:
63     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
64       VPST = std::make_unique<PredicatedMI>(MI, Preds);
65     }
66 
67     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
68       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
69       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
70         Divergent = &Insts.back();
71         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
72       }
73       Insts.emplace_back(MI, Preds);
74       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
75     }
76 
77     // Have we found an instruction within the block which defines the vpr? If
78     // so, not all the instructions in the block will have the same predicate.
79     bool HasNonUniformPredicate() const {
80       return Divergent != nullptr;
81     }
82 
83     // Is the given instruction part of the predicate set controlling the entry
84     // to the block.
85     bool IsPredicatedOn(MachineInstr *MI) const {
86       return VPST->Predicates.count(MI);
87     }
88 
89     // Is the given instruction the only predicate which controls the entry to
90     // the block.
91     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
92       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
93     }
94 
95     unsigned size() const { return Insts.size(); }
96     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
97     MachineInstr *getVPST() const { return VPST->MI; }
98     PredicatedMI *getDivergent() const { return Divergent; }
99   };
100 
101   struct LowOverheadLoop {
102 
103     MachineLoop *ML = nullptr;
104     MachineFunction *MF = nullptr;
105     MachineInstr *InsertPt = nullptr;
106     MachineInstr *Start = nullptr;
107     MachineInstr *Dec = nullptr;
108     MachineInstr *End = nullptr;
109     MachineInstr *VCTP = nullptr;
110     VPTBlock *CurrentBlock = nullptr;
111     SetVector<MachineInstr*> CurrentPredicate;
112     SmallVector<VPTBlock, 4> VPTBlocks;
113     bool Revert = false;
114     bool CannotTailPredicate = false;
115 
116     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
117       MF = ML->getHeader()->getParent();
118     }
119 
120     bool RecordVPTBlocks(MachineInstr *MI);
121 
122     // If this is an MVE instruction, check that we know how to use tail
123     // predication with it.
124     void CheckTPValidity(MachineInstr *MI) {
125       if (CannotTailPredicate)
126         return;
127 
128       if (!RecordVPTBlocks(MI)) {
129         CannotTailPredicate = true;
130         return;
131       }
132 
133       const MCInstrDesc &MCID = MI->getDesc();
134       uint64_t Flags = MCID.TSFlags;
135       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
136         return;
137 
138       if ((Flags & ARMII::ValidForTailPredication) == 0) {
139         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
140         CannotTailPredicate = true;
141       }
142     }
143 
144     bool IsTailPredicationLegal() const {
145       // For now, let's keep things really simple and only support a single
146       // block for tail predication.
147       return !Revert && FoundAllComponents() && VCTP &&
148              !CannotTailPredicate && ML->getNumBlocks() == 1;
149     }
150 
151     // Is it safe to define LR with DLS/WLS?
152     // LR can be defined if it is the operand to start, because it's the same
153     // value, or if it's going to be equivalent to the operand to Start.
154     MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
155 
156     // Check the branch targets are within range and we satisfy our
157     // restrictions.
158     void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA,
159                        MachineLoopInfo *MLI);
160 
161     bool FoundAllComponents() const {
162       return Start && Dec && End;
163     }
164 
165     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
166 
167     // Return the loop iteration count, or the number of elements if we're tail
168     // predicating.
169     MachineOperand &getCount() {
170       return IsTailPredicationLegal() ?
171         VCTP->getOperand(1) : Start->getOperand(0);
172     }
173 
174     unsigned getStartOpcode() const {
175       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
176       if (!IsTailPredicationLegal())
177         return IsDo ? ARM::t2DLS : ARM::t2WLS;
178 
179       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
180     }
181 
182     void dump() const {
183       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
184       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
185       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
186       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
187       if (!FoundAllComponents())
188         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
189       else if (!(Start && Dec && End))
190         dbgs() << "ARM Loops: Failed to find all loop components.\n";
191     }
192   };
193 
194   class ARMLowOverheadLoops : public MachineFunctionPass {
195     MachineFunction           *MF = nullptr;
196     MachineLoopInfo           *MLI = nullptr;
197     ReachingDefAnalysis       *RDA = nullptr;
198     const ARMBaseInstrInfo    *TII = nullptr;
199     MachineRegisterInfo       *MRI = nullptr;
200     const TargetRegisterInfo  *TRI = nullptr;
201     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
202 
203   public:
204     static char ID;
205 
206     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
207 
208     void getAnalysisUsage(AnalysisUsage &AU) const override {
209       AU.setPreservesCFG();
210       AU.addRequired<MachineLoopInfo>();
211       AU.addRequired<ReachingDefAnalysis>();
212       MachineFunctionPass::getAnalysisUsage(AU);
213     }
214 
215     bool runOnMachineFunction(MachineFunction &MF) override;
216 
217     MachineFunctionProperties getRequiredProperties() const override {
218       return MachineFunctionProperties().set(
219           MachineFunctionProperties::Property::NoVRegs).set(
220           MachineFunctionProperties::Property::TracksLiveness);
221     }
222 
223     StringRef getPassName() const override {
224       return ARM_LOW_OVERHEAD_LOOPS_NAME;
225     }
226 
227   private:
228     bool ProcessLoop(MachineLoop *ML);
229 
230     bool RevertNonLoops();
231 
232     void RevertWhile(MachineInstr *MI) const;
233 
234     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
235 
236     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
237 
238     void RemoveLoopUpdate(LowOverheadLoop &LoLoop);
239 
240     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
241 
242     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
243 
244     void Expand(LowOverheadLoop &LoLoop);
245 
246   };
247 }
248 
249 char ARMLowOverheadLoops::ID = 0;
250 
251 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
252                 false, false)
253 
254 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
255   // We can define LR because LR already contains the same value.
256   if (Start->getOperand(0).getReg() == ARM::LR)
257     return Start;
258 
259   unsigned CountReg = Start->getOperand(0).getReg();
260   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
261     return MI->getOpcode() == ARM::tMOVr &&
262            MI->getOperand(0).getReg() == ARM::LR &&
263            MI->getOperand(1).getReg() == CountReg &&
264            MI->getOperand(2).getImm() == ARMCC::AL;
265    };
266 
267   MachineBasicBlock *MBB = Start->getParent();
268 
269   // Find an insertion point:
270   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
271   //   to Count before Start, we can insert at that mov.
272   if (auto *LRDef = RDA->getReachingMIDef(Start, ARM::LR))
273     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
274       return LRDef;
275 
276   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
277   //   to Count after Start, we can insert at that mov.
278   if (auto *LRDef = RDA->getLocalLiveOutMIDef(MBB, ARM::LR))
279     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
280       return LRDef;
281 
282   // We've found no suitable LR def and Start doesn't use LR directly. Can we
283   // just define LR anyway?
284   if (!RDA->isRegUsedAfter(Start, ARM::LR))
285     return Start;
286 
287   return nullptr;
288 }
289 
290 // Can we safely move 'From' to just before 'To'? To satisfy this, 'From' must
291 // not define a register that is used by any instructions, after and including,
292 // 'To'. These instructions also must not redefine any of Froms operands.
293 template<typename Iterator>
294 static bool IsSafeToMove(MachineInstr *From, MachineInstr *To, ReachingDefAnalysis *RDA) {
295   SmallSet<int, 2> Defs;
296   // First check that From would compute the same value if moved.
297   for (auto &MO : From->operands()) {
298     if (!MO.isReg() || MO.isUndef() || !MO.getReg())
299       continue;
300     if (MO.isDef())
301       Defs.insert(MO.getReg());
302     else if (!RDA->hasSameReachingDef(From, To, MO.getReg()))
303       return false;
304   }
305 
306   // Now walk checking that the rest of the instructions will compute the same
307   // value.
308   for (auto I = ++Iterator(From), E = Iterator(To); I != E; ++I) {
309     for (auto &MO : I->operands())
310       if (MO.isReg() && MO.getReg() && MO.isUse() && Defs.count(MO.getReg()))
311         return false;
312   }
313   return true;
314 }
315 
316 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
317                                     ReachingDefAnalysis *RDA,
318                                     MachineLoopInfo *MLI) {
319   if (Revert)
320     return;
321 
322   if (!End->getOperand(1).isMBB())
323     report_fatal_error("Expected LoopEnd to target basic block");
324 
325   // TODO Maybe there's cases where the target doesn't have to be the header,
326   // but for now be safe and revert.
327   if (End->getOperand(1).getMBB() != ML->getHeader()) {
328     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
329     Revert = true;
330     return;
331   }
332 
333   // The WLS and LE instructions have 12-bits for the label offset. WLS
334   // requires a positive offset, while LE uses negative.
335   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
336       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
337     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
338     Revert = true;
339     return;
340   }
341 
342   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
343       (BBUtils->getOffsetOf(Start) >
344        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
345        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
346     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
347     Revert = true;
348     return;
349   }
350 
351   InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
352   if (!InsertPt) {
353     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
354     Revert = true;
355     return;
356   } else
357     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
358 
359   if (!IsTailPredicationLegal()) {
360     LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n");
361     return;
362   }
363 
364   // All predication within the loop should be based on vctp. If the block
365   // isn't predicated on entry, check whether the vctp is within the block
366   // and that all other instructions are then predicated on it.
367   for (auto &Block : VPTBlocks) {
368     if (Block.IsPredicatedOn(VCTP))
369       continue;
370     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) {
371       CannotTailPredicate = true;
372       return;
373     }
374     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
375     for (auto &PredMI : Insts) {
376       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
377         continue;
378       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
379                         << " - which is predicated on:\n";
380                         for (auto *MI : PredMI.Predicates)
381                           dbgs() << "   - " << *MI;
382                  );
383       CannotTailPredicate = true;
384       return;
385     }
386   }
387 
388   // For tail predication, we need to provide the number of elements, instead
389   // of the iteration count, to the loop start instruction. The number of
390   // elements is provided to the vctp instruction, so we need to check that
391   // we can use this register at InsertPt.
392   Register NumElements = VCTP->getOperand(1).getReg();
393 
394   // If the register is defined within loop, then we can't perform TP.
395   // TODO: Check whether this is just a mov of a register that would be
396   // available.
397   if (RDA->getReachingDef(VCTP, NumElements) >= 0) {
398     CannotTailPredicate = true;
399     return;
400   }
401 
402   // The element count register maybe defined after InsertPt, in which case we
403   // need to try to move either InsertPt or the def so that the [w|d]lstp can
404   // use the value.
405   MachineBasicBlock *InsertBB = InsertPt->getParent();
406   if (!RDA->isReachingDefLiveOut(InsertPt, NumElements)) {
407     if (auto *ElemDef = RDA->getLocalLiveOutMIDef(InsertBB, NumElements)) {
408       if (IsSafeToMove<MachineBasicBlock::reverse_iterator>(ElemDef, InsertPt, RDA)) {
409         ElemDef->removeFromParent();
410         InsertBB->insert(MachineBasicBlock::iterator(InsertPt), ElemDef);
411         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
412                    << *ElemDef);
413       } else if (IsSafeToMove<MachineBasicBlock::iterator>(InsertPt, ElemDef, RDA)) {
414         InsertPt->removeFromParent();
415         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), InsertPt);
416         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
417       } else {
418         CannotTailPredicate = true;
419         return;
420       }
421     }
422   }
423 
424   // Especially in the case of while loops, InsertBB may not be the
425   // preheader, so we need to check that the register isn't redefined
426   // before entering the loop.
427   auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB,
428                                       Register NumElements) {
429     // NumElements is redefined in this block.
430     if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0)
431       return true;
432 
433     // Don't continue searching up through multiple predecessors.
434     if (MBB->pred_size() > 1)
435       return true;
436 
437     return false;
438   };
439 
440   // First, find the block that looks like the preheader.
441   MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true);
442   if (!MBB) {
443     CannotTailPredicate = true;
444     return;
445   }
446 
447   // Then search backwards for a def, until we get to InsertBB.
448   while (MBB != InsertBB) {
449     CannotTailPredicate = CannotProvideElements(MBB, NumElements);
450     if (CannotTailPredicate)
451       return;
452     MBB = *MBB->pred_begin();
453   }
454 
455   LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n");
456 }
457 
458 bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) {
459   // Only support a single vctp.
460   if (isVCTP(MI) && VCTP)
461     return false;
462 
463   // Start a new vpt block when we discover a vpt.
464   if (MI->getOpcode() == ARM::MVE_VPST) {
465     VPTBlocks.emplace_back(MI, CurrentPredicate);
466     CurrentBlock = &VPTBlocks.back();
467     return true;
468   }
469 
470   if (isVCTP(MI))
471     VCTP = MI;
472 
473   unsigned VPROpNum = MI->getNumOperands() - 1;
474   bool IsUse = false;
475   if (MI->getOperand(VPROpNum).isReg() &&
476       MI->getOperand(VPROpNum).getReg() == ARM::VPR &&
477       MI->getOperand(VPROpNum).isUse()) {
478     // If this instruction is predicated by VPR, it will be its last
479     // operand.  Also check that it's only 'Then' predicated.
480     if (!MI->getOperand(VPROpNum-1).isImm() ||
481         MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) {
482       LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: "
483                  << *MI);
484       return false;
485     }
486     CurrentBlock->addInst(MI, CurrentPredicate);
487     IsUse = true;
488   }
489 
490   bool IsDef = false;
491   for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) {
492     const MachineOperand &MO = MI->getOperand(i);
493     if (!MO.isReg() || MO.getReg() != ARM::VPR)
494       continue;
495 
496     if (MO.isDef()) {
497       CurrentPredicate.insert(MI);
498       IsDef = true;
499     } else {
500       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
501       return false;
502     }
503   }
504 
505   // If we find a vpr def that is not already predicated on the vctp, we've
506   // got disjoint predicates that may not be equivalent when we do the
507   // conversion.
508   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
509     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
510     return false;
511   }
512 
513   return true;
514 }
515 
516 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
517   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
518   if (!ST.hasLOB())
519     return false;
520 
521   MF = &mf;
522   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
523 
524   MLI = &getAnalysis<MachineLoopInfo>();
525   RDA = &getAnalysis<ReachingDefAnalysis>();
526   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
527   MRI = &MF->getRegInfo();
528   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
529   TRI = ST.getRegisterInfo();
530   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
531   BBUtils->computeAllBlockSizes();
532   BBUtils->adjustBBOffsetsAfter(&MF->front());
533 
534   bool Changed = false;
535   for (auto ML : *MLI) {
536     if (!ML->getParentLoop())
537       Changed |= ProcessLoop(ML);
538   }
539   Changed |= RevertNonLoops();
540   return Changed;
541 }
542 
543 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
544 
545   bool Changed = false;
546 
547   // Process inner loops first.
548   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
549     Changed |= ProcessLoop(*I);
550 
551   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
552              if (auto *Preheader = ML->getLoopPreheader())
553                dbgs() << " - " << Preheader->getName() << "\n";
554              else if (auto *Preheader = MLI->findLoopPreheader(ML))
555                dbgs() << " - " << Preheader->getName() << "\n";
556              for (auto *MBB : ML->getBlocks())
557                dbgs() << " - " << MBB->getName() << "\n";
558             );
559 
560   // Search the given block for a loop start instruction. If one isn't found,
561   // and there's only one predecessor block, search that one too.
562   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
563     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
564     for (auto &MI : *MBB) {
565       if (isLoopStart(MI))
566         return &MI;
567     }
568     if (MBB->pred_size() == 1)
569       return SearchForStart(*MBB->pred_begin());
570     return nullptr;
571   };
572 
573   LowOverheadLoop LoLoop(ML);
574   // Search the preheader for the start intrinsic.
575   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
576   // with potentially multiple set.loop.iterations, so we need to enable this.
577   if (auto *Preheader = ML->getLoopPreheader())
578     LoLoop.Start = SearchForStart(Preheader);
579   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
580     LoLoop.Start = SearchForStart(Preheader);
581   else
582     return false;
583 
584   // Find the low-overhead loop components and decide whether or not to fall
585   // back to a normal loop. Also look for a vctp instructions and decide
586   // whether we can convert that predicate using tail predication.
587   for (auto *MBB : reverse(ML->getBlocks())) {
588     for (auto &MI : *MBB) {
589       if (MI.getOpcode() == ARM::t2LoopDec)
590         LoLoop.Dec = &MI;
591       else if (MI.getOpcode() == ARM::t2LoopEnd)
592         LoLoop.End = &MI;
593       else if (isLoopStart(MI))
594         LoLoop.Start = &MI;
595       else if (MI.getDesc().isCall()) {
596         // TODO: Though the call will require LE to execute again, does this
597         // mean we should revert? Always executing LE hopefully should be
598         // faster than performing a sub,cmp,br or even subs,br.
599         LoLoop.Revert = true;
600         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
601       } else {
602         // Record VPR defs and build up their corresponding vpt blocks.
603         // Check we know how to tail predicate any mve instructions.
604         LoLoop.CheckTPValidity(&MI);
605       }
606 
607       // We need to ensure that LR is not used or defined inbetween LoopDec and
608       // LoopEnd.
609       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
610         continue;
611 
612       // If we find that LR has been written or read between LoopDec and
613       // LoopEnd, expect that the decremented value is being used else where.
614       // Because this value isn't actually going to be produced until the
615       // latch, by LE, we would need to generate a real sub. The value is also
616       // likely to be copied/reloaded for use of LoopEnd - in which in case
617       // we'd need to perform an add because it gets subtracted again by LE!
618       // The other option is to then generate the other form of LE which doesn't
619       // perform the sub.
620       for (auto &MO : MI.operands()) {
621         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
622             MO.getReg() == ARM::LR) {
623           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
624           LoLoop.Revert = true;
625           break;
626         }
627       }
628     }
629   }
630 
631   LLVM_DEBUG(LoLoop.dump());
632   if (!LoLoop.FoundAllComponents())
633     return false;
634 
635   LoLoop.CheckLegality(BBUtils.get(), RDA, MLI);
636   Expand(LoLoop);
637   return true;
638 }
639 
640 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
641 // beq that branches to the exit branch.
642 // TODO: We could also try to generate a cbz if the value in LR is also in
643 // another low register.
644 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
645   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
646   MachineBasicBlock *MBB = MI->getParent();
647   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
648                                     TII->get(ARM::t2CMPri));
649   MIB.add(MI->getOperand(0));
650   MIB.addImm(0);
651   MIB.addImm(ARMCC::AL);
652   MIB.addReg(ARM::NoRegister);
653 
654   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
655   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
656     ARM::tBcc : ARM::t2Bcc;
657 
658   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
659   MIB.add(MI->getOperand(1));   // branch target
660   MIB.addImm(ARMCC::EQ);        // condition code
661   MIB.addReg(ARM::CPSR);
662   MI->eraseFromParent();
663 }
664 
665 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
666                                         bool SetFlags) const {
667   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
668   MachineBasicBlock *MBB = MI->getParent();
669 
670   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
671   if (SetFlags &&
672       (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
673        !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
674       SetFlags = false;
675 
676   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
677                                     TII->get(ARM::t2SUBri));
678   MIB.addDef(ARM::LR);
679   MIB.add(MI->getOperand(1));
680   MIB.add(MI->getOperand(2));
681   MIB.addImm(ARMCC::AL);
682   MIB.addReg(0);
683 
684   if (SetFlags) {
685     MIB.addReg(ARM::CPSR);
686     MIB->getOperand(5).setIsDef(true);
687   } else
688     MIB.addReg(0);
689 
690   MI->eraseFromParent();
691   return SetFlags;
692 }
693 
694 // Generate a subs, or sub and cmp, and a branch instead of an LE.
695 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
696   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
697 
698   MachineBasicBlock *MBB = MI->getParent();
699   // Create cmp
700   if (!SkipCmp) {
701     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
702                                       TII->get(ARM::t2CMPri));
703     MIB.addReg(ARM::LR);
704     MIB.addImm(0);
705     MIB.addImm(ARMCC::AL);
706     MIB.addReg(ARM::NoRegister);
707   }
708 
709   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
710   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
711     ARM::tBcc : ARM::t2Bcc;
712 
713   // Create bne
714   MachineInstrBuilder MIB =
715     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
716   MIB.add(MI->getOperand(1));   // branch target
717   MIB.addImm(ARMCC::NE);        // condition code
718   MIB.addReg(ARM::CPSR);
719   MI->eraseFromParent();
720 }
721 
722 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
723   MachineInstr *InsertPt = LoLoop.InsertPt;
724   MachineInstr *Start = LoLoop.Start;
725   MachineBasicBlock *MBB = InsertPt->getParent();
726   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
727   unsigned Opc = LoLoop.getStartOpcode();
728   MachineOperand &Count = LoLoop.getCount();
729 
730   MachineInstrBuilder MIB =
731     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
732 
733   MIB.addDef(ARM::LR);
734   MIB.add(Count);
735   if (!IsDo)
736     MIB.add(Start->getOperand(1));
737 
738   // When using tail-predication, try to delete the dead code that was used to
739   // calculate the number of loop iterations.
740   if (LoLoop.IsTailPredicationLegal()) {
741     SmallVector<MachineInstr*, 4> Killed;
742     SmallVector<MachineInstr*, 4> Dead;
743     if (auto *Def = RDA->getReachingMIDef(Start,
744                                           Start->getOperand(0).getReg())) {
745       Killed.push_back(Def);
746 
747       while (!Killed.empty()) {
748         MachineInstr *Def = Killed.back();
749         Killed.pop_back();
750         Dead.push_back(Def);
751         for (auto &MO : Def->operands()) {
752           if (!MO.isReg() || !MO.isKill())
753             continue;
754 
755           MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg());
756           if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1)
757             Killed.push_back(Kill);
758         }
759       }
760       for (auto *MI : Dead)
761         MI->eraseFromParent();
762     }
763   }
764 
765   // If we're inserting at a mov lr, then remove it as it's redundant.
766   if (InsertPt != Start)
767     InsertPt->eraseFromParent();
768   Start->eraseFromParent();
769   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
770   return &*MIB;
771 }
772 
773 // Goal is to optimise and clean-up these loops:
774 //
775 //   vector.body:
776 //     renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg
777 //     renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4
778 //     ..
779 //     $lr = MVE_DLSTP_32 renamable $r3
780 //
781 // The SUB is the old update of the loop iteration count expression, which
782 // is no longer needed. This sub is removed when the element count, which is in
783 // r3 in this example, is defined by an instruction in the loop, and it has
784 // no uses.
785 //
786 void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) {
787   Register ElemCount = LoLoop.VCTP->getOperand(1).getReg();
788   MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back();
789 
790   LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n");
791 
792   if (LoLoop.ML->getNumBlocks() != 1) {
793     LLVM_DEBUG(dbgs() << "ARM Loops: single block loop expected\n");
794     return;
795   }
796 
797   LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing MO: ";
798              LoLoop.VCTP->getOperand(1).dump());
799 
800   // Find the definition we are interested in removing, if there is one.
801   MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount);
802   if (!Def)
803     return;
804 
805   // Bail if we define CPSR and it is not dead
806   if (!Def->registerDefIsDead(ARM::CPSR, TRI)) {
807     LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n");
808     return;
809   }
810 
811   // Bail if elemcount is used in exit blocks, i.e. if it is live-in.
812   if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) {
813     LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n");
814     return;
815   }
816 
817   // Bail if there are uses after this Def in the block.
818   SmallVector<MachineInstr*, 4> Uses;
819   RDA->getReachingLocalUses(Def, ElemCount, Uses);
820   if (Uses.size()) {
821     LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n");
822     return;
823   }
824 
825   Uses.clear();
826   RDA->getAllInstWithUseBefore(Def, ElemCount, Uses);
827 
828   // Remove Def if there are no uses, or if the only use is the VCTP
829   // instruction.
830   if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) {
831     LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: ";
832                Def->dump());
833     Def->eraseFromParent();
834   }
835 }
836 
// Rewrite the VPT blocks of a loop that is being tail predicated so that
// predication on the VCTP is dropped (the LETP/DLSTP pair now provides it).
// Blocks are adjusted according to the cases listed below, and finally the
// VCTP instruction itself is erased.
void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
  // Clear the predicate from MI. This rewrites its last two operands: the
  // predicate immediate (asserted to be ARMVCC::Then) becomes ARMVCC::None,
  // and the predicate register operand is set to register 0.
  auto RemovePredicate = [](MachineInstr *MI) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
    unsigned OpNum = MI->getNumOperands() - 1;
    assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then &&
           "Expected Then predicate!");
    MI->getOperand(OpNum-1).setImm(ARMVCC::None);
    MI->getOperand(OpNum).setReg(0);
  };

  // There are a few scenarios which we have to fix up:
  // 1) A VPT block which is only predicated by the vctp and has no internal
  //    vpr defs.
  // 2) A VPT block which is only predicated by the vctp but has an internal
  //    vpr def.
  // 3) A VPT block which is predicated upon the vctp as well as another vpr
  //    def.
  // 4) A VPT block which is not predicated upon a vctp, but contains it and
  //    all instructions within the block are predicated upon it.

  for (auto &Block : LoLoop.getVPTBlocks()) {
    SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
    if (Block.HasNonUniformPredicate()) {
      PredicatedMI *Divergent = Block.getDivergent();
      if (isVCTP(Divergent->MI)) {
        // The vctp will be removed, so the size of the vpt block needs to be
        // modified: re-encode the VPST mask for one fewer instruction.
        uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
        Block.getVPST()->getOperand(0).setImm(Size);
        LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
      } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
        // The VPT block has a non-uniform predicate but its entry is guarded
        // only by a vctp, which means we:
        // - Need to remove the original vpst.
        // - Then need to unpredicate any following instructions, until
        //   we come across the divergent vpr def.
        // - Insert a new vpst to predicate the instruction(s) that follow
        //   the divergent vpr def.
        // TODO: We could be producing more VPT blocks than necessary and could
        // fold the newly created one into a preceding one.
        // Unpredicate every instruction after the VPST, up to and including
        // the divergent instruction.
        for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
             E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
          RemovePredicate(&*I);

        // Walk backwards from the block's last instruction to the divergent
        // instruction, counting the instructions that stay predicated.
        // InsertAt ends up at the earliest of them, which is where the new
        // VPST is placed. This assumes a reverse_iterator built from an
        // instruction points at that instruction.
        // NOTE(review): if Insts.back().MI were Divergent->MI, the loop would
        // not run and InsertAt would remain null. Presumably the VPTBlock
        // invariants rule this out; confirm.
        unsigned Size = 0;
        auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
        auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
        MachineInstr *InsertAt = nullptr;
        while (I != E) {
          InsertAt = &*I;
          ++Size;
          ++I;
        }
        MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
                                          InsertAt->getDebugLoc(),
                                          TII->get(ARM::MVE_VPST));
        MIB.addImm(getARMVPTBlockMask(Size));
        LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
        LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
        Block.getVPST()->eraseFromParent();
      }
    } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
      // A vpt block which is only predicated upon vctp and has no internal vpr
      // defs:
      // - Remove vpst.
      // - Unpredicate the remaining instructions.
      LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
      Block.getVPST()->eraseFromParent();
      for (auto &PredMI : Insts)
        RemovePredicate(PredMI.MI);
    }
  }

  // All blocks have been rewritten, so the VCTP itself can go.
  LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
  LoLoop.VCTP->eraseFromParent();
}
913 
914 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
915 
916   // Combine the LoopDec and LoopEnd instructions into LE(TP).
917   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
918     MachineInstr *End = LoLoop.End;
919     MachineBasicBlock *MBB = End->getParent();
920     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
921       ARM::MVE_LETP : ARM::t2LEUpdate;
922     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
923                                       TII->get(Opc));
924     MIB.addDef(ARM::LR);
925     MIB.add(End->getOperand(0));
926     MIB.add(End->getOperand(1));
927     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
928 
929     LoLoop.End->eraseFromParent();
930     LoLoop.Dec->eraseFromParent();
931     return &*MIB;
932   };
933 
934   // TODO: We should be able to automatically remove these branches before we
935   // get here - probably by teaching analyzeBranch about the pseudo
936   // instructions.
937   // If there is an unconditional branch, after I, that just branches to the
938   // next block, remove it.
939   auto RemoveDeadBranch = [](MachineInstr *I) {
940     MachineBasicBlock *BB = I->getParent();
941     MachineInstr *Terminator = &BB->instr_back();
942     if (Terminator->isUnconditionalBranch() && I != Terminator) {
943       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
944       if (BB->isLayoutSuccessor(Succ)) {
945         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
946         Terminator->eraseFromParent();
947       }
948     }
949   };
950 
951   if (LoLoop.Revert) {
952     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
953       RevertWhile(LoLoop.Start);
954     else
955       LoLoop.Start->eraseFromParent();
956     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
957     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
958   } else {
959     LoLoop.Start = ExpandLoopStart(LoLoop);
960     RemoveDeadBranch(LoLoop.Start);
961     LoLoop.End = ExpandLoopEnd(LoLoop);
962     RemoveDeadBranch(LoLoop.End);
963     if (LoLoop.IsTailPredicationLegal()) {
964       RemoveLoopUpdate(LoLoop);
965       ConvertVPTBlocks(LoLoop);
966     }
967   }
968 }
969 
970 bool ARMLowOverheadLoops::RevertNonLoops() {
971   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
972   bool Changed = false;
973 
974   for (auto &MBB : *MF) {
975     SmallVector<MachineInstr*, 4> Starts;
976     SmallVector<MachineInstr*, 4> Decs;
977     SmallVector<MachineInstr*, 4> Ends;
978 
979     for (auto &I : MBB) {
980       if (isLoopStart(I))
981         Starts.push_back(&I);
982       else if (I.getOpcode() == ARM::t2LoopDec)
983         Decs.push_back(&I);
984       else if (I.getOpcode() == ARM::t2LoopEnd)
985         Ends.push_back(&I);
986     }
987 
988     if (Starts.empty() && Decs.empty() && Ends.empty())
989       continue;
990 
991     Changed = true;
992 
993     for (auto *Start : Starts) {
994       if (Start->getOpcode() == ARM::t2WhileLoopStart)
995         RevertWhile(Start);
996       else
997         Start->eraseFromParent();
998     }
999     for (auto *Dec : Decs)
1000       RevertLoopDec(Dec);
1001 
1002     for (auto *End : Ends)
1003       RevertLoopEnd(End);
1004   }
1005   return Changed;
1006 }
1007 
1008 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1009   return new ARMLowOverheadLoops();
1010 }
1011