xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision e93e0d413f3afa1df5c5f88df546bebcd1183155)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
21 ///
22 //===----------------------------------------------------------------------===//
23 
24 #include "ARM.h"
25 #include "ARMBaseInstrInfo.h"
26 #include "ARMBaseRegisterInfo.h"
27 #include "ARMBasicBlockInfo.h"
28 #include "ARMSubtarget.h"
29 #include "llvm/ADT/SetOperations.h"
30 #include "llvm/ADT/SmallSet.h"
31 #include "llvm/CodeGen/MachineFunctionPass.h"
32 #include "llvm/CodeGen/MachineLoopInfo.h"
33 #include "llvm/CodeGen/MachineLoopUtils.h"
34 #include "llvm/CodeGen/MachineRegisterInfo.h"
35 #include "llvm/CodeGen/Passes.h"
36 #include "llvm/CodeGen/ReachingDefAnalysis.h"
37 #include "llvm/MC/MCInstrDesc.h"
38 
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "arm-low-overhead-loops"
42 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
43 
44 namespace {
45 
46   class PostOrderLoopTraversal {
47     MachineLoop &ML;
48     MachineLoopInfo &MLI;
49     SmallPtrSet<MachineBasicBlock*, 4> Visited;
50     SmallVector<MachineBasicBlock*, 4> Order;
51 
52   public:
53     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
54       : ML(ML), MLI(MLI) { }
55 
56     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
57       return Order;
58     }
59 
60     // Visit all the blocks within the loop, as well as exit blocks and any
61     // blocks properly dominating the header.
62     void ProcessLoop() {
63       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
64         (MachineBasicBlock *MBB) -> void {
65         if (Visited.count(MBB))
66           return;
67 
68         Visited.insert(MBB);
69         for (auto *Succ : MBB->successors()) {
70           if (!ML.contains(Succ))
71             continue;
72           Search(Succ);
73         }
74         Order.push_back(MBB);
75       };
76 
77       // Insert exit blocks.
78       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
79       ML.getExitBlocks(ExitBlocks);
80       for (auto *MBB : ExitBlocks)
81         Order.push_back(MBB);
82 
83       // Then add the loop body.
84       Search(ML.getHeader());
85 
86       // Then try the preheader and its predecessors.
87       std::function<void(MachineBasicBlock*)> GetPredecessor =
88         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
89         Order.push_back(MBB);
90         if (MBB->pred_size() == 1)
91           GetPredecessor(*MBB->pred_begin());
92       };
93 
94       if (auto *Preheader = ML.getLoopPreheader())
95         GetPredecessor(Preheader);
96       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
97         GetPredecessor(Preheader);
98     }
99   };
100 
101   struct PredicatedMI {
102     MachineInstr *MI = nullptr;
103     SetVector<MachineInstr*> Predicates;
104 
105   public:
106     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
107     MI(I) {
108       Predicates.insert(Preds.begin(), Preds.end());
109     }
110   };
111 
112   // Represent a VPT block, a list of instructions that begins with a VPST and
113   // has a maximum of four proceeding instructions. All instructions within the
114   // block are predicated upon the vpr and we allow instructions to define the
115   // vpr within in the block too.
116   class VPTBlock {
117     std::unique_ptr<PredicatedMI> VPST;
118     PredicatedMI *Divergent = nullptr;
119     SmallVector<PredicatedMI, 4> Insts;
120 
121   public:
122     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
123       VPST = std::make_unique<PredicatedMI>(MI, Preds);
124     }
125 
126     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
127       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
128       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
129         Divergent = &Insts.back();
130         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
131       }
132       Insts.emplace_back(MI, Preds);
133       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
134     }
135 
136     // Have we found an instruction within the block which defines the vpr? If
137     // so, not all the instructions in the block will have the same predicate.
138     bool HasNonUniformPredicate() const {
139       return Divergent != nullptr;
140     }
141 
142     // Is the given instruction part of the predicate set controlling the entry
143     // to the block.
144     bool IsPredicatedOn(MachineInstr *MI) const {
145       return VPST->Predicates.count(MI);
146     }
147 
148     // Is the given instruction the only predicate which controls the entry to
149     // the block.
150     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
151       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
152     }
153 
154     unsigned size() const { return Insts.size(); }
155     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
156     MachineInstr *getVPST() const { return VPST->MI; }
157     PredicatedMI *getDivergent() const { return Divergent; }
158   };
159 
160   struct LowOverheadLoop {
161 
162     MachineLoop *ML = nullptr;
163     MachineFunction *MF = nullptr;
164     MachineInstr *InsertPt = nullptr;
165     MachineInstr *Start = nullptr;
166     MachineInstr *Dec = nullptr;
167     MachineInstr *End = nullptr;
168     MachineInstr *VCTP = nullptr;
169     VPTBlock *CurrentBlock = nullptr;
170     SetVector<MachineInstr*> CurrentPredicate;
171     SmallVector<VPTBlock, 4> VPTBlocks;
172     bool Revert = false;
173     bool CannotTailPredicate = false;
174 
175     LowOverheadLoop(MachineLoop *ML) : ML(ML) {
176       MF = ML->getHeader()->getParent();
177     }
178 
179     bool RecordVPTBlocks(MachineInstr *MI);
180 
181     // If this is an MVE instruction, check that we know how to use tail
182     // predication with it.
183     void AnalyseMVEInst(MachineInstr *MI) {
184       if (CannotTailPredicate)
185         return;
186 
187       if (!RecordVPTBlocks(MI)) {
188         CannotTailPredicate = true;
189         return;
190       }
191 
192       const MCInstrDesc &MCID = MI->getDesc();
193       uint64_t Flags = MCID.TSFlags;
194       if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
195         return;
196 
197       if ((Flags & ARMII::ValidForTailPredication) == 0) {
198         LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
199         CannotTailPredicate = true;
200       }
201     }
202 
203     bool IsTailPredicationLegal() const {
204       // For now, let's keep things really simple and only support a single
205       // block for tail predication.
206       return !Revert && FoundAllComponents() && VCTP &&
207              !CannotTailPredicate && ML->getNumBlocks() == 1;
208     }
209 
210     bool ValidateTailPredicate(MachineInstr *StartInsertPt,
211                                ReachingDefAnalysis *RDA,
212                                MachineLoopInfo *MLI);
213 
214     // Is it safe to define LR with DLS/WLS?
215     // LR can be defined if it is the operand to start, because it's the same
216     // value, or if it's going to be equivalent to the operand to Start.
217     MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA);
218 
219     // Check the branch targets are within range and we satisfy our
220     // restrictions.
221     void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA,
222                        MachineLoopInfo *MLI);
223 
224     bool FoundAllComponents() const {
225       return Start && Dec && End;
226     }
227 
228     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
229 
230     // Return the loop iteration count, or the number of elements if we're tail
231     // predicating.
232     MachineOperand &getCount() {
233       return IsTailPredicationLegal() ?
234         VCTP->getOperand(1) : Start->getOperand(0);
235     }
236 
237     unsigned getStartOpcode() const {
238       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
239       if (!IsTailPredicationLegal())
240         return IsDo ? ARM::t2DLS : ARM::t2WLS;
241 
242       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
243     }
244 
245     void dump() const {
246       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
247       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
248       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
249       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
250       if (!FoundAllComponents())
251         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
252       else if (!(Start && Dec && End))
253         dbgs() << "ARM Loops: Failed to find all loop components.\n";
254     }
255   };
256 
257   class ARMLowOverheadLoops : public MachineFunctionPass {
258     MachineFunction           *MF = nullptr;
259     MachineLoopInfo           *MLI = nullptr;
260     ReachingDefAnalysis       *RDA = nullptr;
261     const ARMBaseInstrInfo    *TII = nullptr;
262     MachineRegisterInfo       *MRI = nullptr;
263     const TargetRegisterInfo  *TRI = nullptr;
264     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
265 
266   public:
267     static char ID;
268 
269     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
270 
271     void getAnalysisUsage(AnalysisUsage &AU) const override {
272       AU.setPreservesCFG();
273       AU.addRequired<MachineLoopInfo>();
274       AU.addRequired<ReachingDefAnalysis>();
275       MachineFunctionPass::getAnalysisUsage(AU);
276     }
277 
278     bool runOnMachineFunction(MachineFunction &MF) override;
279 
280     MachineFunctionProperties getRequiredProperties() const override {
281       return MachineFunctionProperties().set(
282           MachineFunctionProperties::Property::NoVRegs).set(
283           MachineFunctionProperties::Property::TracksLiveness);
284     }
285 
286     StringRef getPassName() const override {
287       return ARM_LOW_OVERHEAD_LOOPS_NAME;
288     }
289 
290   private:
291     bool ProcessLoop(MachineLoop *ML);
292 
293     bool RevertNonLoops();
294 
295     void RevertWhile(MachineInstr *MI) const;
296 
297     bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const;
298 
299     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
300 
301     void RemoveLoopUpdate(LowOverheadLoop &LoLoop);
302 
303     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
304 
305     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
306 
307     void Expand(LowOverheadLoop &LoLoop);
308 
309   };
310 }
311 
// Storage for the pass identifier that the ARMLowOverheadLoops constructor
// hands to MachineFunctionPass.
char ARMLowOverheadLoops::ID = 0;

// Register the pass under the DEBUG_TYPE argument string with its
// human-readable name. NOTE(review): the two trailing 'false' arguments are
// presumably the "CFG only" and "is analysis" flags - confirm against the
// INITIALIZE_PASS macro definition.
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
316 
317 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) {
318   // We can define LR because LR already contains the same value.
319   if (Start->getOperand(0).getReg() == ARM::LR)
320     return Start;
321 
322   unsigned CountReg = Start->getOperand(0).getReg();
323   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
324     return MI->getOpcode() == ARM::tMOVr &&
325            MI->getOperand(0).getReg() == ARM::LR &&
326            MI->getOperand(1).getReg() == CountReg &&
327            MI->getOperand(2).getImm() == ARMCC::AL;
328    };
329 
330   MachineBasicBlock *MBB = Start->getParent();
331 
332   // Find an insertion point:
333   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
334   //   to Count before Start, we can insert at that mov.
335   if (auto *LRDef = RDA->getReachingMIDef(Start, ARM::LR))
336     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
337       return LRDef;
338 
339   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
340   //   to Count after Start, we can insert at that mov.
341   if (auto *LRDef = RDA->getLocalLiveOutMIDef(MBB, ARM::LR))
342     if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg))
343       return LRDef;
344 
345   // We've found no suitable LR def and Start doesn't use LR directly. Can we
346   // just define LR anyway?
347   if (!RDA->isRegUsedAfter(Start, ARM::LR))
348     return Start;
349 
350   return nullptr;
351 }
352 
353 // Can we safely move 'From' to just before 'To'? To satisfy this, 'From' must
354 // not define a register that is used by any instructions, after and including,
355 // 'To'. These instructions also must not redefine any of Froms operands.
356 template<typename Iterator>
357 static bool IsSafeToMove(MachineInstr *From, MachineInstr *To, ReachingDefAnalysis *RDA) {
358   SmallSet<int, 2> Defs;
359   // First check that From would compute the same value if moved.
360   for (auto &MO : From->operands()) {
361     if (!MO.isReg() || MO.isUndef() || !MO.getReg())
362       continue;
363     if (MO.isDef())
364       Defs.insert(MO.getReg());
365     else if (!RDA->hasSameReachingDef(From, To, MO.getReg()))
366       return false;
367   }
368 
369   // Now walk checking that the rest of the instructions will compute the same
370   // value.
371   for (auto I = ++Iterator(From), E = Iterator(To); I != E; ++I) {
372     for (auto &MO : I->operands())
373       if (MO.isReg() && MO.getReg() && MO.isUse() && Defs.count(MO.getReg()))
374         return false;
375   }
376   return true;
377 }
378 
379 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt,
380 		                            ReachingDefAnalysis *RDA,
381 		                            MachineLoopInfo *MLI) {
382   // All predication within the loop should be based on vctp. If the block
383   // isn't predicated on entry, check whether the vctp is within the block
384   // and that all other instructions are then predicated on it.
385   for (auto &Block : VPTBlocks) {
386     if (Block.IsPredicatedOn(VCTP))
387       continue;
388     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI))
389       return false;
390     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
391     for (auto &PredMI : Insts) {
392       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
393         continue;
394       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
395                         << " - which is predicated on:\n";
396                         for (auto *MI : PredMI.Predicates)
397                           dbgs() << "   - " << *MI;
398                  );
399       return false;
400     }
401   }
402 
403   // For tail predication, we need to provide the number of elements, instead
404   // of the iteration count, to the loop start instruction. The number of
405   // elements is provided to the vctp instruction, so we need to check that
406   // we can use this register at InsertPt.
407   Register NumElements = VCTP->getOperand(1).getReg();
408 
409   // If the register is defined within loop, then we can't perform TP.
410   // TODO: Check whether this is just a mov of a register that would be
411   // available.
412   if (RDA->getReachingDef(VCTP, NumElements) >= 0) {
413     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
414     return false;
415   }
416 
417   // The element count register maybe defined after InsertPt, in which case we
418   // need to try to move either InsertPt or the def so that the [w|d]lstp can
419   // use the value.
420   MachineBasicBlock *InsertBB = InsertPt->getParent();
421   if (!RDA->isReachingDefLiveOut(InsertPt, NumElements)) {
422     if (auto *ElemDef = RDA->getLocalLiveOutMIDef(InsertBB, NumElements)) {
423       if (IsSafeToMove<MachineBasicBlock::reverse_iterator>(ElemDef, InsertPt, RDA)) {
424         ElemDef->removeFromParent();
425         InsertBB->insert(MachineBasicBlock::iterator(InsertPt), ElemDef);
426         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
427                    << *ElemDef);
428       } else if (IsSafeToMove<MachineBasicBlock::iterator>(InsertPt, ElemDef, RDA)) {
429         InsertPt->removeFromParent();
430         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), InsertPt);
431         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
432       } else
433         return false;
434     }
435   }
436 
437   // Especially in the case of while loops, InsertBB may not be the
438   // preheader, so we need to check that the register isn't redefined
439   // before entering the loop.
440   auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB,
441                                       Register NumElements) {
442     // NumElements is redefined in this block.
443     if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0)
444       return true;
445 
446     // Don't continue searching up through multiple predecessors.
447     if (MBB->pred_size() > 1)
448       return true;
449 
450     return false;
451   };
452 
453   // First, find the block that looks like the preheader.
454   MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true);
455   if (!MBB)
456     return false;
457 
458   // Then search backwards for a def, until we get to InsertBB.
459   while (MBB != InsertBB) {
460     if (CannotProvideElements(MBB, NumElements))
461       return false;
462     MBB = *MBB->pred_begin();
463   }
464 
465   LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n");
466   return true;
467 }
468 
469 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
470                                     ReachingDefAnalysis *RDA,
471                                     MachineLoopInfo *MLI) {
472   if (Revert)
473     return;
474 
475   if (!End->getOperand(1).isMBB())
476     report_fatal_error("Expected LoopEnd to target basic block");
477 
478   // TODO Maybe there's cases where the target doesn't have to be the header,
479   // but for now be safe and revert.
480   if (End->getOperand(1).getMBB() != ML->getHeader()) {
481     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
482     Revert = true;
483     return;
484   }
485 
486   // The WLS and LE instructions have 12-bits for the label offset. WLS
487   // requires a positive offset, while LE uses negative.
488   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) ||
489       !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) {
490     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
491     Revert = true;
492     return;
493   }
494 
495   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
496       (BBUtils->getOffsetOf(Start) >
497        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
498        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
499     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
500     Revert = true;
501     return;
502   }
503 
504   InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA);
505   if (!InsertPt) {
506     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
507     Revert = true;
508     return;
509   } else
510     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
511 
512   if (!IsTailPredicationLegal()) {
513     LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n");
514     return;
515   }
516 
517   assert(ML->getBlocks().size() == 1 &&
518          "Shouldn't be processing a loop with more than one block");
519   CannotTailPredicate = !ValidateTailPredicate(InsertPt, RDA, MLI);
520   LLVM_DEBUG(if (CannotTailPredicate)
521              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
522 }
523 
524 bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) {
525   // Only support a single vctp.
526   if (isVCTP(MI) && VCTP)
527     return false;
528 
529   // Start a new vpt block when we discover a vpt.
530   if (MI->getOpcode() == ARM::MVE_VPST) {
531     VPTBlocks.emplace_back(MI, CurrentPredicate);
532     CurrentBlock = &VPTBlocks.back();
533     return true;
534   }
535 
536   if (isVCTP(MI))
537     VCTP = MI;
538 
539   unsigned VPROpNum = MI->getNumOperands() - 1;
540   bool IsUse = false;
541   if (MI->getOperand(VPROpNum).isReg() &&
542       MI->getOperand(VPROpNum).getReg() == ARM::VPR &&
543       MI->getOperand(VPROpNum).isUse()) {
544     // If this instruction is predicated by VPR, it will be its last
545     // operand.  Also check that it's only 'Then' predicated.
546     if (!MI->getOperand(VPROpNum-1).isImm() ||
547         MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) {
548       LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: "
549                  << *MI);
550       return false;
551     }
552     CurrentBlock->addInst(MI, CurrentPredicate);
553     IsUse = true;
554   }
555 
556   bool IsDef = false;
557   for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) {
558     const MachineOperand &MO = MI->getOperand(i);
559     if (!MO.isReg() || MO.getReg() != ARM::VPR)
560       continue;
561 
562     if (MO.isDef()) {
563       CurrentPredicate.insert(MI);
564       IsDef = true;
565     } else {
566       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
567       return false;
568     }
569   }
570 
571   // If we find a vpr def that is not already predicated on the vctp, we've
572   // got disjoint predicates that may not be equivalent when we do the
573   // conversion.
574   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
575     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
576     return false;
577   }
578 
579   return true;
580 }
581 
582 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
583   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
584   if (!ST.hasLOB())
585     return false;
586 
587   MF = &mf;
588   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
589 
590   MLI = &getAnalysis<MachineLoopInfo>();
591   RDA = &getAnalysis<ReachingDefAnalysis>();
592   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
593   MRI = &MF->getRegInfo();
594   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
595   TRI = ST.getRegisterInfo();
596   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
597   BBUtils->computeAllBlockSizes();
598   BBUtils->adjustBBOffsetsAfter(&MF->front());
599 
600   bool Changed = false;
601   for (auto ML : *MLI) {
602     if (!ML->getParentLoop())
603       Changed |= ProcessLoop(ML);
604   }
605   Changed |= RevertNonLoops();
606   return Changed;
607 }
608 
609 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
610 
611   bool Changed = false;
612 
613   // Process inner loops first.
614   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
615     Changed |= ProcessLoop(*I);
616 
617   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
618              if (auto *Preheader = ML->getLoopPreheader())
619                dbgs() << " - " << Preheader->getName() << "\n";
620              else if (auto *Preheader = MLI->findLoopPreheader(ML))
621                dbgs() << " - " << Preheader->getName() << "\n";
622              for (auto *MBB : ML->getBlocks())
623                dbgs() << " - " << MBB->getName() << "\n";
624             );
625 
626   // Search the given block for a loop start instruction. If one isn't found,
627   // and there's only one predecessor block, search that one too.
628   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
629     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
630     for (auto &MI : *MBB) {
631       if (isLoopStart(MI))
632         return &MI;
633     }
634     if (MBB->pred_size() == 1)
635       return SearchForStart(*MBB->pred_begin());
636     return nullptr;
637   };
638 
639   LowOverheadLoop LoLoop(ML);
640   // Search the preheader for the start intrinsic.
641   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
642   // with potentially multiple set.loop.iterations, so we need to enable this.
643   if (auto *Preheader = ML->getLoopPreheader())
644     LoLoop.Start = SearchForStart(Preheader);
645   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
646     LoLoop.Start = SearchForStart(Preheader);
647   else
648     return false;
649 
650   // Find the low-overhead loop components and decide whether or not to fall
651   // back to a normal loop. Also look for a vctp instructions and decide
652   // whether we can convert that predicate using tail predication.
653   for (auto *MBB : reverse(ML->getBlocks())) {
654     for (auto &MI : *MBB) {
655       if (MI.getOpcode() == ARM::t2LoopDec)
656         LoLoop.Dec = &MI;
657       else if (MI.getOpcode() == ARM::t2LoopEnd)
658         LoLoop.End = &MI;
659       else if (isLoopStart(MI))
660         LoLoop.Start = &MI;
661       else if (MI.getDesc().isCall()) {
662         // TODO: Though the call will require LE to execute again, does this
663         // mean we should revert? Always executing LE hopefully should be
664         // faster than performing a sub,cmp,br or even subs,br.
665         LoLoop.Revert = true;
666         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
667       } else {
668         // Record VPR defs and build up their corresponding vpt blocks.
669         // Check we know how to tail predicate any mve instructions.
670         LoLoop.AnalyseMVEInst(&MI);
671       }
672 
673       // We need to ensure that LR is not used or defined inbetween LoopDec and
674       // LoopEnd.
675       if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert)
676         continue;
677 
678       // If we find that LR has been written or read between LoopDec and
679       // LoopEnd, expect that the decremented value is being used else where.
680       // Because this value isn't actually going to be produced until the
681       // latch, by LE, we would need to generate a real sub. The value is also
682       // likely to be copied/reloaded for use of LoopEnd - in which in case
683       // we'd need to perform an add because it gets subtracted again by LE!
684       // The other option is to then generate the other form of LE which doesn't
685       // perform the sub.
686       for (auto &MO : MI.operands()) {
687         if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() &&
688             MO.getReg() == ARM::LR) {
689           LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI);
690           LoLoop.Revert = true;
691           break;
692         }
693       }
694     }
695   }
696 
697   LLVM_DEBUG(LoLoop.dump());
698   if (!LoLoop.FoundAllComponents())
699     return false;
700 
701   LoLoop.CheckLegality(BBUtils.get(), RDA, MLI);
702   Expand(LoLoop);
703   return true;
704 }
705 
706 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
707 // beq that branches to the exit branch.
708 // TODO: We could also try to generate a cbz if the value in LR is also in
709 // another low register.
710 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
711   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
712   MachineBasicBlock *MBB = MI->getParent();
713   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
714                                     TII->get(ARM::t2CMPri));
715   MIB.add(MI->getOperand(0));
716   MIB.addImm(0);
717   MIB.addImm(ARMCC::AL);
718   MIB.addReg(ARM::NoRegister);
719 
720   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
721   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
722     ARM::tBcc : ARM::t2Bcc;
723 
724   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
725   MIB.add(MI->getOperand(1));   // branch target
726   MIB.addImm(ARMCC::EQ);        // condition code
727   MIB.addReg(ARM::CPSR);
728   MI->eraseFromParent();
729 }
730 
731 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI,
732                                         bool SetFlags) const {
733   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
734   MachineBasicBlock *MBB = MI->getParent();
735 
736   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
737   if (SetFlags &&
738       (RDA->isRegUsedAfter(MI, ARM::CPSR) ||
739        !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR)))
740       SetFlags = false;
741 
742   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
743                                     TII->get(ARM::t2SUBri));
744   MIB.addDef(ARM::LR);
745   MIB.add(MI->getOperand(1));
746   MIB.add(MI->getOperand(2));
747   MIB.addImm(ARMCC::AL);
748   MIB.addReg(0);
749 
750   if (SetFlags) {
751     MIB.addReg(ARM::CPSR);
752     MIB->getOperand(5).setIsDef(true);
753   } else
754     MIB.addReg(0);
755 
756   MI->eraseFromParent();
757   return SetFlags;
758 }
759 
760 // Generate a subs, or sub and cmp, and a branch instead of an LE.
761 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
762   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
763 
764   MachineBasicBlock *MBB = MI->getParent();
765   // Create cmp
766   if (!SkipCmp) {
767     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
768                                       TII->get(ARM::t2CMPri));
769     MIB.addReg(ARM::LR);
770     MIB.addImm(0);
771     MIB.addImm(ARMCC::AL);
772     MIB.addReg(ARM::NoRegister);
773   }
774 
775   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
776   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
777     ARM::tBcc : ARM::t2Bcc;
778 
779   // Create bne
780   MachineInstrBuilder MIB =
781     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
782   MIB.add(MI->getOperand(1));   // branch target
783   MIB.addImm(ARMCC::NE);        // condition code
784   MIB.addReg(ARM::CPSR);
785   MI->eraseFromParent();
786 }
787 
788 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
789   MachineInstr *InsertPt = LoLoop.InsertPt;
790   MachineInstr *Start = LoLoop.Start;
791   MachineBasicBlock *MBB = InsertPt->getParent();
792   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
793   unsigned Opc = LoLoop.getStartOpcode();
794   MachineOperand &Count = LoLoop.getCount();
795 
796   MachineInstrBuilder MIB =
797     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
798 
799   MIB.addDef(ARM::LR);
800   MIB.add(Count);
801   if (!IsDo)
802     MIB.add(Start->getOperand(1));
803 
804   // When using tail-predication, try to delete the dead code that was used to
805   // calculate the number of loop iterations.
806   if (LoLoop.IsTailPredicationLegal()) {
807     SmallVector<MachineInstr*, 4> Killed;
808     SmallVector<MachineInstr*, 4> Dead;
809     if (auto *Def = RDA->getReachingMIDef(Start,
810                                           Start->getOperand(0).getReg())) {
811       Killed.push_back(Def);
812 
813       while (!Killed.empty()) {
814         MachineInstr *Def = Killed.back();
815         Killed.pop_back();
816         Dead.push_back(Def);
817         for (auto &MO : Def->operands()) {
818           if (!MO.isReg() || !MO.isKill())
819             continue;
820 
821           MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg());
822           if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1)
823             Killed.push_back(Kill);
824         }
825       }
826       for (auto *MI : Dead)
827         MI->eraseFromParent();
828     }
829   }
830 
831   // If we're inserting at a mov lr, then remove it as it's redundant.
832   if (InsertPt != Start)
833     InsertPt->eraseFromParent();
834   Start->eraseFromParent();
835   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
836   return &*MIB;
837 }
838 
839 // Goal is to optimise and clean-up these loops:
840 //
841 //   vector.body:
842 //     renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg
843 //     renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4
844 //     ..
845 //     $lr = MVE_DLSTP_32 renamable $r3
846 //
847 // The SUB is the old update of the loop iteration count expression, which
848 // is no longer needed. This sub is removed when the element count, which is in
849 // r3 in this example, is defined by an instruction in the loop, and it has
850 // no uses.
851 //
852 void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) {
853   Register ElemCount = LoLoop.VCTP->getOperand(1).getReg();
854   MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back();
855 
856   LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n");
857 
858   if (LoLoop.ML->getNumBlocks() != 1) {
859     LLVM_DEBUG(dbgs() << "ARM Loops: single block loop expected\n");
860     return;
861   }
862 
863   LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing MO: ";
864              LoLoop.VCTP->getOperand(1).dump());
865 
866   // Find the definition we are interested in removing, if there is one.
867   MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount);
868   if (!Def)
869     return;
870 
871   // Bail if we define CPSR and it is not dead
872   if (!Def->registerDefIsDead(ARM::CPSR, TRI)) {
873     LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n");
874     return;
875   }
876 
877   // Bail if elemcount is used in exit blocks, i.e. if it is live-in.
878   if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) {
879     LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n");
880     return;
881   }
882 
883   // Bail if there are uses after this Def in the block.
884   SmallVector<MachineInstr*, 4> Uses;
885   RDA->getReachingLocalUses(Def, ElemCount, Uses);
886   if (Uses.size()) {
887     LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n");
888     return;
889   }
890 
891   Uses.clear();
892   RDA->getAllInstWithUseBefore(Def, ElemCount, Uses);
893 
894   // Remove Def if there are no uses, or if the only use is the VCTP
895   // instruction.
896   if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) {
897     LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: ";
898                Def->dump());
899     Def->eraseFromParent();
900   }
901 }
902 
903 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
904   auto RemovePredicate = [](MachineInstr *MI) {
905     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
906     unsigned OpNum = MI->getNumOperands() - 1;
907     assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then &&
908            "Expected Then predicate!");
909     MI->getOperand(OpNum-1).setImm(ARMVCC::None);
910     MI->getOperand(OpNum).setReg(0);
911   };
912 
913   // There are a few scenarios which we have to fix up:
914   // 1) A VPT block with is only predicated by the vctp and has no internal vpr
915   //    defs.
916   // 2) A VPT block which is only predicated by the vctp but has an internal
917   //    vpr def.
918   // 3) A VPT block which is predicated upon the vctp as well as another vpr
919   //    def.
920   // 4) A VPT block which is not predicated upon a vctp, but contains it and
921   //    all instructions within the block are predicated upon in.
922 
923   for (auto &Block : LoLoop.getVPTBlocks()) {
924     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
925     if (Block.HasNonUniformPredicate()) {
926       PredicatedMI *Divergent = Block.getDivergent();
927       if (isVCTP(Divergent->MI)) {
928         // The vctp will be removed, so the size of the vpt block needs to be
929         // modified.
930         uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
931         Block.getVPST()->getOperand(0).setImm(Size);
932         LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
933       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
934         // The VPT block has a non-uniform predicate but it's entry is guarded
935         // only by a vctp, which means we:
936         // - Need to remove the original vpst.
937         // - Then need to unpredicate any following instructions, until
938         //   we come across the divergent vpr def.
939         // - Insert a new vpst to predicate the instruction(s) that following
940         //   the divergent vpr def.
941         // TODO: We could be producing more VPT blocks than necessary and could
942         // fold the newly created one into a proceeding one.
943         for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
944              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
945           RemovePredicate(&*I);
946 
947         unsigned Size = 0;
948         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
949         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
950         MachineInstr *InsertAt = nullptr;
951         while (I != E) {
952           InsertAt = &*I;
953           ++Size;
954           ++I;
955         }
956         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
957                                           InsertAt->getDebugLoc(),
958                                           TII->get(ARM::MVE_VPST));
959         MIB.addImm(getARMVPTBlockMask(Size));
960         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
961         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
962         Block.getVPST()->eraseFromParent();
963       }
964     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
965       // A vpt block which is only predicated upon vctp and has no internal vpr
966       // defs:
967       // - Remove vpst.
968       // - Unpredicate the remaining instructions.
969       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
970       Block.getVPST()->eraseFromParent();
971       for (auto &PredMI : Insts)
972         RemovePredicate(PredMI.MI);
973     }
974   }
975 
976   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
977   LoLoop.VCTP->eraseFromParent();
978 }
979 
980 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
981 
982   // Combine the LoopDec and LoopEnd instructions into LE(TP).
983   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
984     MachineInstr *End = LoLoop.End;
985     MachineBasicBlock *MBB = End->getParent();
986     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
987       ARM::MVE_LETP : ARM::t2LEUpdate;
988     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
989                                       TII->get(Opc));
990     MIB.addDef(ARM::LR);
991     MIB.add(End->getOperand(0));
992     MIB.add(End->getOperand(1));
993     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
994 
995     LoLoop.End->eraseFromParent();
996     LoLoop.Dec->eraseFromParent();
997     return &*MIB;
998   };
999 
1000   // TODO: We should be able to automatically remove these branches before we
1001   // get here - probably by teaching analyzeBranch about the pseudo
1002   // instructions.
1003   // If there is an unconditional branch, after I, that just branches to the
1004   // next block, remove it.
1005   auto RemoveDeadBranch = [](MachineInstr *I) {
1006     MachineBasicBlock *BB = I->getParent();
1007     MachineInstr *Terminator = &BB->instr_back();
1008     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1009       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1010       if (BB->isLayoutSuccessor(Succ)) {
1011         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1012         Terminator->eraseFromParent();
1013       }
1014     }
1015   };
1016 
1017   if (LoLoop.Revert) {
1018     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1019       RevertWhile(LoLoop.Start);
1020     else
1021       LoLoop.Start->eraseFromParent();
1022     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true);
1023     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1024   } else {
1025     LoLoop.Start = ExpandLoopStart(LoLoop);
1026     RemoveDeadBranch(LoLoop.Start);
1027     LoLoop.End = ExpandLoopEnd(LoLoop);
1028     RemoveDeadBranch(LoLoop.End);
1029     if (LoLoop.IsTailPredicationLegal()) {
1030       RemoveLoopUpdate(LoLoop);
1031       ConvertVPTBlocks(LoLoop);
1032     }
1033   }
1034 
1035   PostOrderLoopTraversal DFS(*LoLoop.ML, *MLI);
1036   DFS.ProcessLoop();
1037   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1038   for (auto *MBB : PostOrder)
1039     recomputeLiveIns(*MBB);
1040 
1041   for (auto *MBB : reverse(PostOrder))
1042     recomputeLivenessFlags(*MBB);
1043 }
1044 
1045 bool ARMLowOverheadLoops::RevertNonLoops() {
1046   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1047   bool Changed = false;
1048 
1049   for (auto &MBB : *MF) {
1050     SmallVector<MachineInstr*, 4> Starts;
1051     SmallVector<MachineInstr*, 4> Decs;
1052     SmallVector<MachineInstr*, 4> Ends;
1053 
1054     for (auto &I : MBB) {
1055       if (isLoopStart(I))
1056         Starts.push_back(&I);
1057       else if (I.getOpcode() == ARM::t2LoopDec)
1058         Decs.push_back(&I);
1059       else if (I.getOpcode() == ARM::t2LoopEnd)
1060         Ends.push_back(&I);
1061     }
1062 
1063     if (Starts.empty() && Decs.empty() && Ends.empty())
1064       continue;
1065 
1066     Changed = true;
1067 
1068     for (auto *Start : Starts) {
1069       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1070         RevertWhile(Start);
1071       else
1072         Start->eraseFromParent();
1073     }
1074     for (auto *Dec : Decs)
1075       RevertLoopDec(Dec);
1076 
1077     for (auto *End : Ends)
1078       RevertLoopEnd(End);
1079   }
1080   return Changed;
1081 }
1082 
1083 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1084   return new ARMLowOverheadLoops();
1085 }
1086