1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
14 ///   preheader's only predecessor.
15 /// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 /// In addition to this, we also look for the presence of the VCTP instruction,
19 /// which determines whether we can generate the tail-predicated low-overhead
20 /// loop form.
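///
/// For example (illustrative; register names are examples only), a do-loop is
/// expected to arrive at this pass roughly as:
/// \verbatim
///   preheader:
///     $lr = <loop iteration count>
///     t2DoLoopStart renamable $lr
///   vector.body:
///     ...
///     renamable $lr = t2LoopDec killed renamable $lr, 1
///     t2LoopEnd renamable $lr, %vector.body
/// \endverbatim
/// and is finalized into a DLS/LE pair, or into a DLSTP/LETP pair when a VCTP
/// allows the tail-predicated form.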
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a limited (though large)
26 /// range and a fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 /// A note on VPR.P0 (the lane mask):
39 /// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
40 /// "VPT Active" context (which includes low-overhead loops and vpt blocks).
41 /// They will simply "and" the result of their calculation with the current
42 /// value of VPR.P0. You can think of it like this:
43 /// \verbatim
44 /// if VPT active:    ; Between a DLSTP/LETP, or for predicated instrs
45 ///   VPR.P0 &= Value
46 /// else
47 ///   VPR.P0 = Value
48 /// \endverbatim
49 /// When we're inside the low-overhead loop (between DLSTP and LETP), we always
50 /// fall into the "VPT active" case, so we can consider that all VPR writes by
51 /// one of those instructions are actually an "and".
52 //===----------------------------------------------------------------------===//
53 
54 #include "ARM.h"
55 #include "ARMBaseInstrInfo.h"
56 #include "ARMBaseRegisterInfo.h"
57 #include "ARMBasicBlockInfo.h"
58 #include "ARMSubtarget.h"
59 #include "Thumb2InstrInfo.h"
60 #include "llvm/ADT/SetOperations.h"
61 #include "llvm/ADT/SmallSet.h"
62 #include "llvm/CodeGen/LivePhysRegs.h"
63 #include "llvm/CodeGen/MachineFunctionPass.h"
64 #include "llvm/CodeGen/MachineLoopInfo.h"
65 #include "llvm/CodeGen/MachineLoopUtils.h"
66 #include "llvm/CodeGen/MachineRegisterInfo.h"
67 #include "llvm/CodeGen/Passes.h"
68 #include "llvm/CodeGen/ReachingDefAnalysis.h"
69 #include "llvm/MC/MCInstrDesc.h"
70 
71 using namespace llvm;
72 
73 #define DEBUG_TYPE "arm-low-overhead-loops"
74 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
75 
76 namespace {
77 
78   using InstSet = SmallPtrSetImpl<MachineInstr *>;
79 
80   class PostOrderLoopTraversal {
81     MachineLoop &ML;
82     MachineLoopInfo &MLI;
83     SmallPtrSet<MachineBasicBlock*, 4> Visited;
84     SmallVector<MachineBasicBlock*, 4> Order;
85 
86   public:
87     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
88       : ML(ML), MLI(MLI) { }
89 
90     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
91       return Order;
92     }
93 
94     // Visit all the blocks within the loop, as well as exit blocks and any
95     // blocks properly dominating the header.
96     void ProcessLoop() {
97       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
98         (MachineBasicBlock *MBB) -> void {
99         if (Visited.count(MBB))
100           return;
101 
102         Visited.insert(MBB);
103         for (auto *Succ : MBB->successors()) {
104           if (!ML.contains(Succ))
105             continue;
106           Search(Succ);
107         }
108         Order.push_back(MBB);
109       };
110 
111       // Insert exit blocks.
112       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
113       ML.getExitBlocks(ExitBlocks);
114       for (auto *MBB : ExitBlocks)
115         Order.push_back(MBB);
116 
117       // Then add the loop body.
118       Search(ML.getHeader());
119 
120       // Then try the preheader and its predecessors.
121       std::function<void(MachineBasicBlock*)> GetPredecessor =
122         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
123         Order.push_back(MBB);
124         if (MBB->pred_size() == 1)
125           GetPredecessor(*MBB->pred_begin());
126       };
127 
128       if (auto *Preheader = ML.getLoopPreheader())
129         GetPredecessor(Preheader);
130       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
131         GetPredecessor(Preheader);
132     }
133   };
134 
135   struct PredicatedMI {
136     MachineInstr *MI = nullptr;
137     SetVector<MachineInstr*> Predicates;
138 
139   public:
140     PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
141       assert(I && "Instruction must not be null!");
142       Predicates.insert(Preds.begin(), Preds.end());
143     }
144   };
145 
146   // Represent a VPT block, a list of instructions that begins with a VPT/VPST
147   // and has a maximum of four following instructions. All instructions within
148   // the block are predicated upon the vpr, and we allow instructions to define
149   // the vpr within the block too.
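  // For example (illustrative), a typical block is a VPST followed by up to
  // four predicated MVE instructions, such as a predicated VLDR feeding a
  // predicated VADD, whose vpred operands all use $vpr.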
150   class VPTBlock {
151     // The predicate then instruction, which is either a VPT, or a VPST
152     // instruction.
153     std::unique_ptr<PredicatedMI> PredicateThen;
154     PredicatedMI *Divergent = nullptr;
155     SmallVector<PredicatedMI, 4> Insts;
156 
157   public:
158     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
159       PredicateThen = std::make_unique<PredicatedMI>(MI, Preds);
160     }
161 
162     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
163       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
164       if (!Divergent && !set_difference(Preds, PredicateThen->Predicates).empty()) {
165         Divergent = &Insts.back();
166         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
167       }
168       Insts.emplace_back(MI, Preds);
169       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
170     }
171 
172     // Have we found an instruction within the block which defines the vpr? If
173     // so, not all the instructions in the block will have the same predicate.
174     bool HasNonUniformPredicate() const {
175       return Divergent != nullptr;
176     }
177 
178     // Is the given instruction part of the predicate set controlling the entry
179     // to the block?
180     bool IsPredicatedOn(MachineInstr *MI) const {
181       return PredicateThen->Predicates.count(MI);
182     }
183 
184     // Returns true if this is a VPT instruction.
185     bool isVPT() const { return !isVPST(); }
186 
187     // Returns true if this is a VPST instruction.
188     bool isVPST() const {
189       return PredicateThen->MI->getOpcode() == ARM::MVE_VPST;
190     }
191 
192     // Is the given instruction the only predicate which controls the entry to
193     // the block?
194     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
195       return IsPredicatedOn(MI) && PredicateThen->Predicates.size() == 1;
196     }
197 
198     unsigned size() const { return Insts.size(); }
199     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
200     MachineInstr *getPredicateThen() const { return PredicateThen->MI; }
201     PredicatedMI *getDivergent() const { return Divergent; }
202   };
203 
204   struct LowOverheadLoop {
205 
206     MachineLoop &ML;
207     MachineBasicBlock *Preheader = nullptr;
208     MachineLoopInfo &MLI;
209     ReachingDefAnalysis &RDA;
210     const TargetRegisterInfo &TRI;
211     const ARMBaseInstrInfo &TII;
212     MachineFunction *MF = nullptr;
213     MachineInstr *InsertPt = nullptr;
214     MachineInstr *Start = nullptr;
215     MachineInstr *Dec = nullptr;
216     MachineInstr *End = nullptr;
217     MachineInstr *VCTP = nullptr;
218     MachineOperand TPNumElements;
219     SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs;
220     VPTBlock *CurrentBlock = nullptr;
221     SetVector<MachineInstr*> CurrentPredicate;
222     SmallVector<VPTBlock, 4> VPTBlocks;
223     SmallPtrSet<MachineInstr*, 4> ToRemove;
224     SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
225     bool Revert = false;
226     bool CannotTailPredicate = false;
227 
228     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
229                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
230                     const ARMBaseInstrInfo &TII)
231         : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII),
232           TPNumElements(MachineOperand::CreateImm(0)) {
233       MF = ML.getHeader()->getParent();
234       if (auto *MBB = ML.getLoopPreheader())
235         Preheader = MBB;
236       else if (auto *MBB = MLI.findLoopPreheader(&ML, true))
237         Preheader = MBB;
238     }
239 
240     // If this is an MVE instruction, check that we know how to use tail
241     // predication with it. Record VPT blocks and return whether the
242     // instruction is valid for tail predication.
243     bool ValidateMVEInst(MachineInstr *MI);
244 
245     void AnalyseMVEInst(MachineInstr *MI) {
246       CannotTailPredicate = !ValidateMVEInst(MI);
247     }
248 
249     bool IsTailPredicationLegal() const {
250       // For now, let's keep things really simple and only support a single
251       // block for tail predication.
252       return !Revert && FoundAllComponents() && VCTP &&
253              !CannotTailPredicate && ML.getNumBlocks() == 1;
254     }
255 
256     // Check that the predication in the loop will be equivalent once we
257     // perform the conversion. Also ensure that we can provide the number
258     // of elements to the loop start instruction.
259     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
260 
261     // Check that any values available outside of the loop will be the same
262     // after tail predication conversion.
263     bool ValidateLiveOuts();
264 
265     // Is it safe to define LR with DLS/WLS?
266     // LR can be defined if it is the operand to Start, because it's the same
267     // value, or if it's going to be equivalent to the operand to Start.
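    // For example (register names illustrative), a preheader containing:
    //   $lr = tMOVr $r2, 14, $noreg
    //   t2DoLoopStart $r2
    // allows the DLS to be created at the tMOVr, since LR will already hold
    // the same value as the count register.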
268     MachineInstr *isSafeToDefineLR();
269 
270     // Check the branch targets are within range and we satisfy our
271     // restrictions.
272     void CheckLegality(ARMBasicBlockUtils *BBUtils);
273 
274     bool FoundAllComponents() const {
275       return Start && Dec && End;
276     }
277 
278     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
279 
280     // Return the operand for the loop start instruction. This will be the loop
281     // iteration count, or the number of elements if we're tail predicating.
282     MachineOperand &getLoopStartOperand() {
283       return IsTailPredicationLegal() ? TPNumElements : Start->getOperand(0);
284     }
285 
286     unsigned getStartOpcode() const {
287       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
288       if (!IsTailPredicationLegal())
289         return IsDo ? ARM::t2DLS : ARM::t2WLS;
290 
291       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
292     }
293 
294     void dump() const {
295       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
296       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
297       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
298       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
299       if (!FoundAllComponents())
300         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
301       else if (!(Start && Dec && End))
302         dbgs() << "ARM Loops: Failed to find all loop components.\n";
303     }
304   };
305 
306   class ARMLowOverheadLoops : public MachineFunctionPass {
307     MachineFunction           *MF = nullptr;
308     MachineLoopInfo           *MLI = nullptr;
309     ReachingDefAnalysis       *RDA = nullptr;
310     const ARMBaseInstrInfo    *TII = nullptr;
311     MachineRegisterInfo       *MRI = nullptr;
312     const TargetRegisterInfo  *TRI = nullptr;
313     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
314 
315   public:
316     static char ID;
317 
318     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
319 
320     void getAnalysisUsage(AnalysisUsage &AU) const override {
321       AU.setPreservesCFG();
322       AU.addRequired<MachineLoopInfo>();
323       AU.addRequired<ReachingDefAnalysis>();
324       MachineFunctionPass::getAnalysisUsage(AU);
325     }
326 
327     bool runOnMachineFunction(MachineFunction &MF) override;
328 
329     MachineFunctionProperties getRequiredProperties() const override {
330       return MachineFunctionProperties().set(
331           MachineFunctionProperties::Property::NoVRegs).set(
332           MachineFunctionProperties::Property::TracksLiveness);
333     }
334 
335     StringRef getPassName() const override {
336       return ARM_LOW_OVERHEAD_LOOPS_NAME;
337     }
338 
339   private:
340     bool ProcessLoop(MachineLoop *ML);
341 
342     bool RevertNonLoops();
343 
344     void RevertWhile(MachineInstr *MI) const;
345 
346     bool RevertLoopDec(MachineInstr *MI) const;
347 
348     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
349 
350     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
351 
352     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
353 
354     void Expand(LowOverheadLoop &LoLoop);
355 
356     void IterationCountDCE(LowOverheadLoop &LoLoop);
357   };
358 }
359 
360 char ARMLowOverheadLoops::ID = 0;
361 
362 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
363                 false, false)
364 
365 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
366   // We can define LR because LR already contains the same value.
367   if (Start->getOperand(0).getReg() == ARM::LR)
368     return Start;
369 
370   unsigned CountReg = Start->getOperand(0).getReg();
371   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
372     return MI->getOpcode() == ARM::tMOVr &&
373            MI->getOperand(0).getReg() == ARM::LR &&
374            MI->getOperand(1).getReg() == CountReg &&
375            MI->getOperand(2).getImm() == ARMCC::AL;
376    };
377 
378   MachineBasicBlock *MBB = Start->getParent();
379 
380   // Find an insertion point:
381   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
382   //   to Count before Start, we can insert at that mov.
383   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
384     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
385       return LRDef;
386 
387   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
388   //   to Count after Start, we can insert at that mov.
389   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
390     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
391       return LRDef;
392 
393   // We've found no suitable LR def and Start doesn't use LR directly. Can we
394   // just define LR anyway?
395   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
396 }
397 
398 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
399   assert(VCTP && "VCTP instruction expected but is not set");
400   // All predication within the loop should be based on vctp. If the block
401   // isn't predicated on entry, check whether the vctp is within the block
402   // and that all other instructions are then predicated on it.
403   for (auto &Block : VPTBlocks) {
404     if (Block.IsPredicatedOn(VCTP))
405       continue;
406     if (Block.HasNonUniformPredicate() && !isVCTP(Block.getDivergent()->MI)) {
407       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
408                         << *Block.getDivergent()->MI);
409       return false;
410     }
411     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
412     for (auto &PredMI : Insts) {
413       // Check the instructions in the block and only allow:
414       //   - VCTPs
415       //   - Instructions predicated on the main VCTP
416       //   - Any VCMP
417       //      - VCMPs just "and" their result with VPR.P0. Whether they are
418       //      located before/after the VCTP is irrelevant - the end result will
419       //      be the same in both cases, so there's no point in requiring them
420       //      to be located after the VCTP!
421       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI) ||
422           VCMPOpcodeToVPT(PredMI.MI->getOpcode()) != 0)
423         continue;
424       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
425                  << " - which is predicated on:\n";
426                  for (auto *MI : PredMI.Predicates)
427                    dbgs() << "   - " << *MI);
428       return false;
429     }
430   }
431 
432   if (!ValidateLiveOuts()) {
433     LLVM_DEBUG(dbgs() << "ARM Loops: Invalid live outs.\n");
434     return false;
435   }
436 
437   // For tail predication, we need to provide the number of elements, instead
438   // of the iteration count, to the loop start instruction. The number of
439   // elements is provided to the vctp instruction, so we need to check that
440   // we can use this register at InsertPt.
441   TPNumElements = VCTP->getOperand(1);
442   Register NumElements = TPNumElements.getReg();
443 
444   // If the register is defined within the loop, then we can't perform TP.
445   // TODO: Check whether this is just a mov of a register that would be
446   // available.
447   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
448     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
449     return false;
450   }
451 
452   // The element count register may be defined after InsertPt, in which case we
453   // need to try to move either InsertPt or the def so that the [w|d]lstp can
454   // use the value.
455   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
456 
457   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
458     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
459       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
460         ElemDef->removeFromParent();
461         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
462         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
463                    << *ElemDef);
464       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
465         StartInsertPt->removeFromParent();
466         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
467                               StartInsertPt);
468         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
469       } else {
470         // If we fail to move an instruction and the element count is provided
471         // by a mov, use the mov operand if it will have the same value at the
472         // insertion point.
473         MachineOperand Operand = ElemDef->getOperand(1);
474         if (isMovRegOpcode(ElemDef->getOpcode()) &&
475             RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg()) ==
476                 RDA.getUniqueReachingMIDef(StartInsertPt, Operand.getReg())) {
477           TPNumElements = Operand;
478           NumElements = TPNumElements.getReg();
479         } else {
480           LLVM_DEBUG(dbgs()
481                      << "ARM Loops: Unable to move element count to loop "
482                      << "start instruction.\n");
483           return false;
484         }
485       }
486     }
487   }
488 
489   // Especially in the case of while loops, InsertBB may not be the
490   // preheader, so we need to check that the register isn't redefined
491   // before entering the loop.
492   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
493                                       Register NumElements) {
494     // NumElements is redefined in this block.
495     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
496       return true;
497 
498     // Don't continue searching up through multiple predecessors.
499     if (MBB->pred_size() > 1)
500       return true;
501 
502     return false;
503   };
504 
505   // First, find the block that looks like the preheader.
506   MachineBasicBlock *MBB = Preheader;
507   if (!MBB) {
508     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
509     return false;
510   }
511 
512   // Then search backwards for a def, until we get to InsertBB.
513   while (MBB != InsertBB) {
514     if (CannotProvideElements(MBB, NumElements)) {
515       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
516       return false;
517     }
518     MBB = *MBB->pred_begin();
519   }
520 
521   // Check that the value change of the element count is what we expect and
522   // that the predication will be equivalent. For this we need:
523   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
524   // and we can also allow register copies within the chain too.
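  // For example (registers illustrative), a typical element count chain is:
  //   $vpr = MVE_VCTP32 renamable $r3
  //   ...
  //   renamable $r3 = t2SUBri killed renamable $r3, 4, 14, $noreg, $noreg
  // i.e. a sub of the vector width, possibly reached through register copies.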
525   auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
526     return -getAddSubImmediate(*MI) == ExpectedVecWidth;
527   };
528 
529   MBB = VCTP->getParent();
530   // Remove modifications to the element count since they have no purpose in a
531   // tail predicated loop. Explicitly refer to the vctp operand no matter which
532   // register NumElements has been assigned to, since that is what the
533   // modifications will be using.
534   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(),
535                                              VCTP->getOperand(1).getReg())) {
536     SmallPtrSet<MachineInstr*, 2> ElementChain;
537     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
538     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
539 
540     Ignore.insert(SecondaryVCTPs.begin(), SecondaryVCTPs.end());
541 
542     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
543       bool FoundSub = false;
544 
545       for (auto *MI : ElementChain) {
546         if (isMovRegOpcode(MI->getOpcode()))
547           continue;
548 
549         if (isSubImmOpcode(MI->getOpcode())) {
550           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
551             return false;
552           FoundSub = true;
553         } else
554           return false;
555       }
556 
557       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
558                  for (auto *MI : ElementChain)
559                    dbgs() << " - " << *MI);
560       ToRemove.insert(ElementChain.begin(), ElementChain.end());
561     }
562   }
563   return true;
564 }
565 
566 static bool isVectorPredicated(MachineInstr *MI) {
567   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
568   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
569 }
570 
571 static bool isRegInClass(const MachineOperand &MO,
572                          const TargetRegisterClass *Class) {
573   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
574 }
575 
576 // MVE 'narrowing' operations operate on half a lane, reading from half and
577 // writing to half, which are referred to as the top and bottom half. The other
578 // half retains its previous value.
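// For example, the narrowing moves VMOVNB and VMOVNT write only the bottom or
// top half of each destination element and leave the other half untouched.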
579 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
580   const MCInstrDesc &MCID = MI.getDesc();
581   uint64_t Flags = MCID.TSFlags;
582   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
583 }
584 
585 // Some MVE instructions read from the top/bottom halves of their operand(s)
586 // and generate a vector result with result elements that are double the
587 // width of the input.
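// For example, the long multiplies VMULLB and VMULLT read the bottom or top
// halves of their operands and produce elements of twice the width.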
588 static bool producesDoubleWidthResult(const MachineInstr &MI) {
589   const MCInstrDesc &MCID = MI.getDesc();
590   uint64_t Flags = MCID.TSFlags;
591   return (Flags & ARMII::DoubleWidthResult) != 0;
592 }
593 
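// Horizontal reductions, such as VADDV or VMAXV, accumulate across the vector
// lanes and produce a scalar result.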
594 static bool isHorizontalReduction(const MachineInstr &MI) {
595   const MCInstrDesc &MCID = MI.getDesc();
596   uint64_t Flags = MCID.TSFlags;
597   return (Flags & ARMII::HorizontalReduction) != 0;
598 }
599 
600 // Can this instruction generate a non-zero result when given only zeroed
601 // operands? This allows us to know that, given operands with false bytes
602 // zeroed by masked loads, the result will also contain zeros in those
603 // bytes.
604 static bool canGenerateNonZeros(const MachineInstr &MI) {
605 
606   // Check for instructions which can write into a larger element size,
607   // possibly writing into a previous zero'd lane.
608   if (producesDoubleWidthResult(MI))
609     return true;
610 
611   switch (MI.getOpcode()) {
612   default:
613     break;
614   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
615   // fp16 -> fp32 vector conversions.
616   // Instructions that perform a NOT will generate 1s from 0s.
617   case ARM::MVE_VMVN:
618   case ARM::MVE_VORN:
619   // Count leading zeros will do just that!
620   case ARM::MVE_VCLZs8:
621   case ARM::MVE_VCLZs16:
622   case ARM::MVE_VCLZs32:
623     return true;
624   }
625   return false;
626 }
627 
628 // Look at the register uses to see if the instruction can only receive zeros
629 // into its false lanes, which would then produce zeros. Also check that
630 // the output register is also defined by a FalseLanesZero instruction
631 // so that if tail-predication happens, the lanes that aren't updated will
632 // still be zeros.
633 static bool producesFalseLanesZero(MachineInstr &MI,
634                                    const TargetRegisterClass *QPRs,
635                                    const ReachingDefAnalysis &RDA,
636                                    InstSet &FalseLanesZero) {
637   if (canGenerateNonZeros(MI))
638     return false;
639 
640   bool isPredicated = isVectorPredicated(&MI);
641   // Predicated loads will write zeros to the falsely predicated bytes of the
642   // destination register.
643   if (MI.mayLoad())
644     return isPredicated;
645 
646   auto IsZeroInit = [](MachineInstr *Def) {
647     return !isVectorPredicated(Def) &&
648            Def->getOpcode() == ARM::MVE_VMOVimmi32 &&
649            Def->getOperand(1).getImm() == 0;
650   };
651 
652   bool AllowScalars = isHorizontalReduction(MI);
653   for (auto &MO : MI.operands()) {
654     if (!MO.isReg() || !MO.getReg())
655       continue;
656     if (!isRegInClass(MO, QPRs) && AllowScalars)
657       continue;
658 
659     // Check that this instruction will produce zeros in its false lanes:
660     // - If it only consumes false lanes zero or constant 0 (vmov #0)
661     //   - If it's predicated, it only matters that its def register already has
662     //   false lane zeros, so we can ignore the uses.
663     SmallPtrSet<MachineInstr *, 2> Defs;
664     RDA.getGlobalReachingDefs(&MI, MO.getReg(), Defs);
665     for (auto *Def : Defs) {
666       if (Def == &MI || FalseLanesZero.count(Def) || IsZeroInit(Def))
667         continue;
668       if (MO.isUse() && isPredicated)
669         continue;
670       return false;
671     }
672   }
673   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
674   return true;
675 }
676 
677 bool LowOverheadLoop::ValidateLiveOuts() {
678   // We want to find out if the tail-predicated version of this loop will
679   // produce the same values as the loop in its original form. For this to
680   // be true, the newly inserted implicit predication must not change the
681   // (observable) results.
682   // We're doing this because many instructions in the loop will not be
683   // predicated and so the conversion from VPT predication to tail-predication
684   // can result in different values being produced, due to the tail-predication
685   // preventing many instructions from updating their falsely predicated
686   // lanes. This analysis assumes that all the instructions perform lane-wise
687   // operations and don't perform any exchanges.
688   // A masked load, whether through VPT or tail predication, will write zeros
689   // to any of the falsely predicated bytes. So, from the loads, we know that
690   // the false lanes are zeroed and here we're trying to track that those false
691   // lanes remain zero, or where they change, the differences are masked away
692   // by their user(s).
693   // All MVE stores have to be predicated, so we know that any predicated load
694   // operands, or stored results are equivalent already. Other explicitly
695   // predicated instructions will perform the same operation in the original
696   // loop and the tail-predicated form too. Because of this, we can insert
697   // loads, stores and other predicated instructions into our Predicated
698   // set and build from there.
699   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
700   SetVector<MachineInstr *> FalseLanesUnknown;
701   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
702   SmallPtrSet<MachineInstr *, 4> Predicated;
703   MachineBasicBlock *Header = ML.getHeader();
704 
705   for (auto &MI : *Header) {
706     const MCInstrDesc &MCID = MI.getDesc();
707     uint64_t Flags = MCID.TSFlags;
708     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
709       continue;
710 
711     if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
712       continue;
713 
714     bool isPredicated = isVectorPredicated(&MI);
715     bool retainsOrReduces =
716       retainsPreviousHalfElement(MI) || isHorizontalReduction(MI);
717 
718     if (isPredicated)
719       Predicated.insert(&MI);
720     if (producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero))
721       FalseLanesZero.insert(&MI);
722     else if (MI.getNumDefs() == 0)
723       continue;
724     else if (!isPredicated && retainsOrReduces)
725       return false;
726     else if (!isPredicated)
727       FalseLanesUnknown.insert(&MI);
728   }
729 
730   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
731                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
732     SmallPtrSet<MachineInstr *, 2> Uses;
733     RDA.getGlobalUses(MI, MO.getReg(), Uses);
734     for (auto *Use : Uses) {
735       if (Use != MI && !Predicated.count(Use))
736         return false;
737     }
738     return true;
739   };
740 
741   // Visit the unknowns in reverse so that we can start at the values being
742   // stored and then we can work towards the leaves, hopefully adding more
743   // instructions to Predicated. Successfully terminating the loop means that
744   // all the unknown values have been found to be masked by predicated user(s).
745   // For any unpredicated values, we store them in NonPredicated so that we
746   // can later check whether these form a reduction.
747   SmallPtrSet<MachineInstr*, 2> NonPredicated;
748   for (auto *MI : reverse(FalseLanesUnknown)) {
749     for (auto &MO : MI->operands()) {
750       if (!isRegInClass(MO, QPRs) || !MO.isDef())
751         continue;
752       if (!HasPredicatedUsers(MI, MO, Predicated)) {
753         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
754                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
755         NonPredicated.insert(MI);
756         break;
757       }
758     }
759     // Any unknown false lanes have been masked away by the user(s).
760     if (!NonPredicated.contains(MI))
761       Predicated.insert(MI);
762   }
763 
764   SmallPtrSet<MachineInstr *, 2> LiveOutMIs;
765   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
766   ML.getExitBlocks(ExitBlocks);
767   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
768   assert(ExitBlocks.size() == 1 && "Expected a single exit block");
769   MachineBasicBlock *ExitBB = ExitBlocks.front();
770   for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) {
771     // TODO: Instead of blocking predication, we could move the vctp to the exit
772     // block and calculate its operand there or in the preheader.
773     if (RegMask.PhysReg == ARM::VPR)
774       return false;
775     // Check Q-regs that are live in the exit blocks. We don't collect scalars
776     // because they won't be affected by lane predication.
777     if (QPRs->contains(RegMask.PhysReg))
778       if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg))
779         LiveOutMIs.insert(MI);
780   }
781 
782   // We've already validated that any VPT predication within the loop will be
783   // equivalent when we perform the predication transformation; so we know that
784   // any VPT predicated instruction is predicated upon VCTP. Any live-out
785   // instruction needs to be predicated, so check this here. The instructions
786   // in NonPredicated have been found to form a reduction whose legality we
787   // can ensure.
788   for (auto *MI : LiveOutMIs) {
789     if (NonPredicated.count(MI) && FalseLanesUnknown.contains(MI)) {
790       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to handle live out: " << *MI);
791       return false;
792     }
793   }
794 
795   return true;
796 }
797 
798 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
799   if (Revert)
800     return;
801 
802   if (!End->getOperand(1).isMBB())
803     report_fatal_error("Expected LoopEnd to target basic block");
804 
805   // TODO: Maybe there are cases where the target doesn't have to be the header,
806   // but for now be safe and revert.
807   if (End->getOperand(1).getMBB() != ML.getHeader()) {
808     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targeting header.\n");
809     Revert = true;
810     return;
811   }
812 
813   // The WLS and LE instructions have 12 bits for the label offset. WLS
814   // requires a positive offset, while LE uses negative.
815   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
816       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
817     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
818     Revert = true;
819     return;
820   }
821 
822   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
823       (BBUtils->getOffsetOf(Start) >
824        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
825        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
826     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
827     Revert = true;
828     return;
829   }
830 
831   InsertPt = Revert ? nullptr : isSafeToDefineLR();
832   if (!InsertPt) {
833     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
834     Revert = true;
835     return;
836   } else
837     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
838 
839   if (!IsTailPredicationLegal()) {
840     LLVM_DEBUG(if (!VCTP)
841                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
842                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
843     return;
844   }
845 
846   assert(ML.getBlocks().size() == 1 &&
847          "Shouldn't be processing a loop with more than one block");
848   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
849   LLVM_DEBUG(if (CannotTailPredicate)
850              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
851 }
852 
853 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
854   if (CannotTailPredicate)
855     return false;
856 
857   if (isVCTP(MI)) {
858     // If we find another VCTP, check whether it uses the same value as the
859     // main VCTP: if so, store it in the SecondaryVCTPs set, else refuse it.
860     if (VCTP) {
861       if (!VCTP->getOperand(1).isIdenticalTo(MI->getOperand(1)) ||
862           !RDA.hasSameReachingDef(VCTP, MI, MI->getOperand(1).getReg())) {
863         LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
864                              "definition from the main VCTP");
865         return false;
866       }
867       LLVM_DEBUG(dbgs() << "ARM Loops: Found secondary VCTP: " << *MI);
868       SecondaryVCTPs.insert(MI);
869     } else {
870       LLVM_DEBUG(dbgs() << "ARM Loops: Found 'main' VCTP: " << *MI);
871       VCTP = MI;
872     }
873   } else if (isVPTOpcode(MI->getOpcode())) {
874     if (MI->getOpcode() != ARM::MVE_VPST) {
875       assert(MI->findRegisterDefOperandIdx(ARM::VPR) != -1 &&
876              "VPT does not implicitly define VPR?!");
877       CurrentPredicate.insert(MI);
878     }
879 
880     VPTBlocks.emplace_back(MI, CurrentPredicate);
881     CurrentBlock = &VPTBlocks.back();
882     return true;
883   } else if (MI->getOpcode() == ARM::MVE_VPSEL ||
884              MI->getOpcode() == ARM::MVE_VPNOT) {
885     // TODO: Allow VPSEL and VPNOT, we currently cannot because:
886     // 1) It will use the VPR as a predicate operand, but doesn't have to be
887     //    inside a VPT block, which means we can assert while building up
888     //    the VPT block because we don't find another VPT or VPST to begin a
889     //    new one.
890     // 2) VPSEL still requires a VPR operand even after tail predicating,
891     //    which means we can't remove it unless there is another
892     //    instruction, such as vcmp, that can provide the VPR def.
893     return false;
894   }
895 
896   bool IsUse = false;
897   bool IsDef = false;
898   const MCInstrDesc &MCID = MI->getDesc();
899   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
900     const MachineOperand &MO = MI->getOperand(i);
901     if (!MO.isReg() || MO.getReg() != ARM::VPR)
902       continue;
903 
904     if (MO.isDef()) {
905       CurrentPredicate.insert(MI);
906       IsDef = true;
907     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
908       CurrentBlock->addInst(MI, CurrentPredicate);
909       IsUse = true;
910     } else {
911       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
912       return false;
913     }
914   }
915 
916   // If we find a vpr def that is not already predicated on the vctp, we've
917   // got disjoint predicates that may not be equivalent when we do the
918   // conversion.
919   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
920     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
921     return false;
922   }
923 
924   uint64_t Flags = MCID.TSFlags;
925   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
926     return true;
927 
928   // If we find an instruction that has been marked as not valid for tail
929   // predication, only allow the instruction if it's contained within a valid
930   // VPT block.
931   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
932     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
933     return false;
934   }
935 
936   // If the instruction is already explicitly predicated, then the conversion
937   // will be fine, but ensure that all store operations are predicated.
938   return !IsUse && MI->mayStore() ? false : true;
939 }
940 
941 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
942   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
943   if (!ST.hasLOB())
944     return false;
945 
946   MF = &mf;
947   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
948 
949   MLI = &getAnalysis<MachineLoopInfo>();
950   RDA = &getAnalysis<ReachingDefAnalysis>();
951   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
952   MRI = &MF->getRegInfo();
953   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
954   TRI = ST.getRegisterInfo();
955   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
956   BBUtils->computeAllBlockSizes();
957   BBUtils->adjustBBOffsetsAfter(&MF->front());
958 
959   bool Changed = false;
960   for (auto ML : *MLI) {
961     if (!ML->getParentLoop())
962       Changed |= ProcessLoop(ML);
963   }
964   Changed |= RevertNonLoops();
965   return Changed;
966 }
967 
968 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
969 
970   bool Changed = false;
971 
972   // Process inner loops first.
973   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
974     Changed |= ProcessLoop(*I);
975 
976   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
977              if (auto *Preheader = ML->getLoopPreheader())
978                dbgs() << " - " << Preheader->getName() << "\n";
979              else if (auto *Preheader = MLI->findLoopPreheader(ML))
980                dbgs() << " - " << Preheader->getName() << "\n";
981              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
982                dbgs() << " - " << Preheader->getName() << "\n";
983              for (auto *MBB : ML->getBlocks())
984                dbgs() << " - " << MBB->getName() << "\n";
985             );
986 
987   // Search the given block for a loop start instruction. If one isn't found,
988   // and there's only one predecessor block, search that one too.
989   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
990     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
991     for (auto &MI : *MBB) {
992       if (isLoopStart(MI))
993         return &MI;
994     }
995     if (MBB->pred_size() == 1)
996       return SearchForStart(*MBB->pred_begin());
997     return nullptr;
998   };
999 
1000   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII);
1001   // Search the preheader for the start intrinsic.
1002   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
1003   // with potentially multiple set.loop.iterations, so we need to enable this.
1004   if (LoLoop.Preheader)
1005     LoLoop.Start = SearchForStart(LoLoop.Preheader);
1006   else
1007     return false;
1008 
1009   // Find the low-overhead loop components and decide whether or not to fall
1010   // back to a normal loop. Also look for a vctp instruction and decide
1011   // whether we can convert that predicate using tail predication.
1012   for (auto *MBB : reverse(ML->getBlocks())) {
1013     for (auto &MI : *MBB) {
1014       if (MI.isDebugValue())
1015         continue;
1016       else if (MI.getOpcode() == ARM::t2LoopDec)
1017         LoLoop.Dec = &MI;
1018       else if (MI.getOpcode() == ARM::t2LoopEnd)
1019         LoLoop.End = &MI;
1020       else if (isLoopStart(MI))
1021         LoLoop.Start = &MI;
1022       else if (MI.getDesc().isCall()) {
1023         // TODO: Though the call will require LE to execute again, does this
1024         // mean we should revert? Always executing LE hopefully should be
1025         // faster than performing a sub,cmp,br or even subs,br.
1026         LoLoop.Revert = true;
1027         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
1028       } else {
1029         // Record VPR defs and build up their corresponding vpt blocks.
1030         // Check we know how to tail predicate any mve instructions.
1031         LoLoop.AnalyseMVEInst(&MI);
1032       }
1033     }
1034   }
1035 
1036   LLVM_DEBUG(LoLoop.dump());
1037   if (!LoLoop.FoundAllComponents()) {
1038     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
1039     return false;
1040   }
1041 
1042   // Check that the only instruction using LoopDec is LoopEnd.
1043   // TODO: Check for copy chains that really have no effect.
1044   SmallPtrSet<MachineInstr*, 2> Uses;
1045   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
1046   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
1047     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
1048     LoLoop.Revert = true;
1049   }
1050   LoLoop.CheckLegality(BBUtils.get());
1051   Expand(LoLoop);
1052   return true;
1053 }
1054 
1055 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
1056 // beq that branches to the exit block.
1057 // TODO: We could also try to generate a cbz if the value in LR is also in
1058 // another low register.
1059 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
1060   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
1061   MachineBasicBlock *MBB = MI->getParent();
1062   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1063                                     TII->get(ARM::t2CMPri));
1064   MIB.add(MI->getOperand(0));
1065   MIB.addImm(0);
1066   MIB.addImm(ARMCC::AL);
1067   MIB.addReg(ARM::NoRegister);
1068 
1069   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1070   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1071     ARM::tBcc : ARM::t2Bcc;
1072 
1073   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1074   MIB.add(MI->getOperand(1));   // branch target
1075   MIB.addImm(ARMCC::EQ);        // condition code
1076   MIB.addReg(ARM::CPSR);
1077   MI->eraseFromParent();
1078 }
1079 
1080 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
1081   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
1082   MachineBasicBlock *MBB = MI->getParent();
1083   SmallPtrSet<MachineInstr*, 1> Ignore;
1084   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
1085     if (I->getOpcode() == ARM::t2LoopEnd) {
1086       Ignore.insert(&*I);
1087       break;
1088     }
1089   }
1090 
1091   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
1092   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
1093 
1094   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1095                                     TII->get(ARM::t2SUBri));
1096   MIB.addDef(ARM::LR);
1097   MIB.add(MI->getOperand(1));
1098   MIB.add(MI->getOperand(2));
1099   MIB.addImm(ARMCC::AL);
1100   MIB.addReg(0);
1101 
1102   if (SetFlags) {
1103     MIB.addReg(ARM::CPSR);
1104     MIB->getOperand(5).setIsDef(true);
1105   } else
1106     MIB.addReg(0);
1107 
1108   MI->eraseFromParent();
1109   return SetFlags;
1110 }
1111 
1112 // Generate a subs, or sub and cmp, and a branch instead of an LE.
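// For example, with SkipCmp == false, the t2LoopEnd is replaced by (operands
// illustrative):
//   t2CMPri $lr, 0, 14, $noreg
//   t2Bcc %vector.body, 1 /* ne */, $cpsr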
1113 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1114   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1115 
1116   MachineBasicBlock *MBB = MI->getParent();
1117   // Create cmp
1118   if (!SkipCmp) {
1119     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1120                                       TII->get(ARM::t2CMPri));
1121     MIB.addReg(ARM::LR);
1122     MIB.addImm(0);
1123     MIB.addImm(ARMCC::AL);
1124     MIB.addReg(ARM::NoRegister);
1125   }
1126 
1127   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1128   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1129     ARM::tBcc : ARM::t2Bcc;
1130 
1131   // Create bne
1132   MachineInstrBuilder MIB =
1133     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1134   MIB.add(MI->getOperand(1));   // branch target
1135   MIB.addImm(ARMCC::NE);        // condition code
1136   MIB.addReg(ARM::CPSR);
1137   MI->eraseFromParent();
1138 }
1139 
1140 // Perform dead code elimination on the loop iteration count setup expression.
1141 // If we are tail-predicating, the number of elements to be processed is the
1142 // operand of the VCTP instruction in the vector body, see getCount(), which is
1143 // register $r3 in this example:
1144 //
1145 //   $lr = big-itercount-expression
1146 //   ..
1147 //   t2DoLoopStart renamable $lr
1148 //   vector.body:
1149 //     ..
1150 //     $vpr = MVE_VCTP32 renamable $r3
1151 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1152 //     t2LoopEnd renamable $lr, %vector.body
1153 //     tB %end
1154 //
1155 // What we would like to achieve here is to replace the do-loop start pseudo
1156 // instruction t2DoLoopStart with:
1157 //
1158 //    $lr = MVE_DLSTP_32 killed renamable $r3
1159 //
1160 // Thus, $r3 which defines the number of elements, is written to $lr,
1161 // and then we want to delete the whole chain that used to define $lr,
1162 // see the comment below for how this chain could look.
1163 //
1164 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1165   if (!LoLoop.IsTailPredicationLegal())
1166     return;
1167 
1168   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1169 
1170   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1171   if (!Def) {
1172     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1173     return;
1174   }
1175 
1176   // Collect and remove the users of iteration count.
1177   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1178                                             LoLoop.End, LoLoop.InsertPt };
1179   SmallPtrSet<MachineInstr*, 2> Remove;
1180   if (RDA->isSafeToRemove(Def, Remove, Killed))
1181     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1182   else {
1183     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1184     return;
1185   }
1186 
1187   // Collect the dead code and the MBBs in which they reside.
1188   RDA->collectKilledOperands(Def, Killed);
1189   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1190   for (auto *MI : Killed)
1191     BasicBlocks.insert(MI->getParent());
1192 
1193   // Collect IT blocks in all affected basic blocks.
1194   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1195   for (auto *MBB : BasicBlocks) {
1196     for (auto &MI : *MBB) {
1197       if (MI.getOpcode() != ARM::t2IT)
1198         continue;
1199       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1200     }
1201   }
1202 
1203   // If we're removing all of the instructions within an IT block, then
1204   // also remove the IT instruction.
1205   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1206   for (auto *MI : Killed) {
1207     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1208       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1209       auto &CurrentBlock = ITBlocks[IT];
1210       CurrentBlock.erase(MI);
1211       if (CurrentBlock.empty())
1212         ModifiedITs.erase(IT);
1213       else
1214         ModifiedITs.insert(IT);
1215     }
1216   }
1217 
1218   // Delete the killed instructions only if we don't have any IT blocks that
1219   // need to be modified because we need to fixup the mask.
1220   // TODO: Handle cases where IT blocks are modified.
1221   if (ModifiedITs.empty()) {
1222     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1223                for (auto *MI : Killed)
1224                  dbgs() << " - " << *MI);
1225     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1226   } else
1227     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1228 }
1229 
1230 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1231   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1232   // When using tail-predication, try to delete the dead code that was used to
1233   // calculate the number of loop iterations.
1234   IterationCountDCE(LoLoop);
1235 
1236   MachineInstr *InsertPt = LoLoop.InsertPt;
1237   MachineInstr *Start = LoLoop.Start;
1238   MachineBasicBlock *MBB = InsertPt->getParent();
1239   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1240   unsigned Opc = LoLoop.getStartOpcode();
1241   MachineOperand &Count = LoLoop.getLoopStartOperand();
1242 
1243   MachineInstrBuilder MIB =
1244     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1245 
1246   MIB.addDef(ARM::LR);
1247   MIB.add(Count);
1248   if (!IsDo)
1249     MIB.add(Start->getOperand(1));
1250 
1251   // If we're inserting at a mov lr, then remove it as it's redundant.
1252   if (InsertPt != Start)
1253     LoLoop.ToRemove.insert(InsertPt);
1254   LoLoop.ToRemove.insert(Start);
1255   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1256   return &*MIB;
1257 }
1258 
1259 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1260   auto RemovePredicate = [](MachineInstr *MI) {
1261     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1262     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1263       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1264              "Expected Then predicate!");
1265       MI->getOperand(PIdx).setImm(ARMVCC::None);
1266       MI->getOperand(PIdx+1).setReg(0);
1267     } else
1268       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1269   };
1270 
1271   // There are a few scenarios which we have to fix up:
1272   // 1. VPT Blocks with non-uniform predicates:
1273   //    - a. When the divergent instruction is a vctp
1274   //    - b. When the block uses a vpst, and is only predicated on the vctp
1275   //    - c. When the block uses a vpt and (optionally) contains one or more
1276   //         vctp.
1277   // 2. VPT Blocks with uniform predicates:
1278   //    - a. The block uses a vpst, and is only predicated on the vctp
1279   for (auto &Block : LoLoop.getVPTBlocks()) {
1280     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1281     if (Block.HasNonUniformPredicate()) {
1282       PredicatedMI *Divergent = Block.getDivergent();
1283       if (isVCTP(Divergent->MI)) {
1284         // The vctp will be removed, so the block mask of the vp(s)t will need
1285         // to be recomputed.
1286         LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1287       } else if (Block.isVPST() && Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1288         // The VPT block has a non-uniform predicate but it uses a vpst and its
1289         // entry is guarded only by a vctp, which means we:
1290         // - Need to remove the original vpst.
1291         // - Then need to unpredicate any following instructions, until
1292         //   we come across the divergent vpr def.
1293         //   - Insert a new vpst to predicate the instruction(s) that follow
1294         //   the divergent vpr def.
1295         // TODO: We could be producing more VPT blocks than necessary and could
1296         // fold the newly created one into a preceding one.
1297         for (auto I = ++MachineBasicBlock::iterator(Block.getPredicateThen()),
1298              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1299           RemovePredicate(&*I);
1300 
1301         unsigned Size = 0;
1302         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1303         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1304         MachineInstr *InsertAt = nullptr;
1305         while (I != E) {
1306           InsertAt = &*I;
1307           ++Size;
1308           ++I;
1309         }
1310         // Create a VPST (with a null mask for now, we'll recompute it later).
1311         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1312                                           InsertAt->getDebugLoc(),
1313                                           TII->get(ARM::MVE_VPST));
1314         MIB.addImm(0);
1315         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1316         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1317         LoLoop.ToRemove.insert(Block.getPredicateThen());
1318         LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
1319       }
1320       // Else, if the block uses a vpt, iterate over the block, removing the
1321       // extra VCTPs it may contain.
1322       else if (Block.isVPT()) {
1323         bool RemovedVCTP = false;
1324         for (PredicatedMI &Elt : Block.getInsts()) {
1325           MachineInstr *MI = Elt.MI;
1326           if (isVCTP(MI)) {
1327             LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *MI);
1328             LoLoop.ToRemove.insert(MI);
1329             RemovedVCTP = true;
1330             continue;
1331           }
1332         }
1333         if (RemovedVCTP)
1334           LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1335       }
1336     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP) && Block.isVPST()) {
1337       // A vpt block that starts with a VPST, is only predicated upon the vctp,
1338       // and has no internal vpr defs:
1339       // - Remove vpst.
1340       // - Unpredicate the remaining instructions.
1341       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1342       LoLoop.ToRemove.insert(Block.getPredicateThen());
1343       for (auto &PredMI : Insts)
1344         RemovePredicate(PredMI.MI);
1345     }
1346   }
1347   LLVM_DEBUG(dbgs() << "ARM Loops: Removing remaining VCTPs...\n");
1348   // Remove the "main" VCTP
1349   LoLoop.ToRemove.insert(LoLoop.VCTP);
1350   LLVM_DEBUG(dbgs() << "    " << *LoLoop.VCTP);
1351   // Remove remaining secondary VCTPs
1352   for (MachineInstr *VCTP : LoLoop.SecondaryVCTPs) {
1353     // All VCTPs that aren't marked for removal yet should be unpredicated ones.
1354     // The predicated ones should have already been marked for removal when
1355     // visiting the VPT blocks.
1356     if (LoLoop.ToRemove.insert(VCTP).second) {
1357       assert(getVPTInstrPredicate(*VCTP) == ARMVCC::None &&
1358              "Removing Predicated VCTP without updating the block mask!");
1359       LLVM_DEBUG(dbgs() << "    " << *VCTP);
1360     }
1361   }
1362 }
1363 
1364 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1365 
1366   // Combine the LoopDec and LoopEnd instructions into LE(TP).
1367   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1368     MachineInstr *End = LoLoop.End;
1369     MachineBasicBlock *MBB = End->getParent();
1370     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1371       ARM::MVE_LETP : ARM::t2LEUpdate;
1372     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1373                                       TII->get(Opc));
1374     MIB.addDef(ARM::LR);
1375     MIB.add(End->getOperand(0));
1376     MIB.add(End->getOperand(1));
1377     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1378     LoLoop.ToRemove.insert(LoLoop.Dec);
1379     LoLoop.ToRemove.insert(End);
1380     return &*MIB;
1381   };
1382 
1383   // TODO: We should be able to automatically remove these branches before we
1384   // get here - probably by teaching analyzeBranch about the pseudo
1385   // instructions.
1386   // If there is an unconditional branch, after I, that just branches to the
1387   // next block, remove it.
1388   auto RemoveDeadBranch = [](MachineInstr *I) {
1389     MachineBasicBlock *BB = I->getParent();
1390     MachineInstr *Terminator = &BB->instr_back();
1391     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1392       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1393       if (BB->isLayoutSuccessor(Succ)) {
1394         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1395         Terminator->eraseFromParent();
1396       }
1397     }
1398   };
1399 
1400   if (LoLoop.Revert) {
1401     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1402       RevertWhile(LoLoop.Start);
1403     else
1404       LoLoop.Start->eraseFromParent();
1405     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
1406     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1407   } else {
1408     LoLoop.Start = ExpandLoopStart(LoLoop);
1409     RemoveDeadBranch(LoLoop.Start);
1410     LoLoop.End = ExpandLoopEnd(LoLoop);
1411     RemoveDeadBranch(LoLoop.End);
1412     if (LoLoop.IsTailPredicationLegal())
1413       ConvertVPTBlocks(LoLoop);
1414     for (auto *I : LoLoop.ToRemove) {
1415       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1416       I->eraseFromParent();
1417     }
1418     for (auto *I : LoLoop.BlockMasksToRecompute) {
1419       LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
1420       recomputeVPTBlockMask(*I);
1421       LLVM_DEBUG(dbgs() << "           ... done: " << *I);
1422     }
1423   }
1424 
1425   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1426   DFS.ProcessLoop();
1427   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1428   for (auto *MBB : PostOrder) {
1429     recomputeLiveIns(*MBB);
1430     // FIXME: For some reason, the live-in print order is non-deterministic for
1431     // our tests and I can't work out why... So just sort them.
1432     MBB->sortUniqueLiveIns();
1433   }
1434 
1435   for (auto *MBB : reverse(PostOrder))
1436     recomputeLivenessFlags(*MBB);
1437 
1438   // We've moved, removed and inserted new instructions, so update RDA.
1439   RDA->reset();
1440 }
1441 
1442 bool ARMLowOverheadLoops::RevertNonLoops() {
1443   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1444   bool Changed = false;
1445 
1446   for (auto &MBB : *MF) {
1447     SmallVector<MachineInstr*, 4> Starts;
1448     SmallVector<MachineInstr*, 4> Decs;
1449     SmallVector<MachineInstr*, 4> Ends;
1450 
1451     for (auto &I : MBB) {
1452       if (isLoopStart(I))
1453         Starts.push_back(&I);
1454       else if (I.getOpcode() == ARM::t2LoopDec)
1455         Decs.push_back(&I);
1456       else if (I.getOpcode() == ARM::t2LoopEnd)
1457         Ends.push_back(&I);
1458     }
1459 
1460     if (Starts.empty() && Decs.empty() && Ends.empty())
1461       continue;
1462 
1463     Changed = true;
1464 
1465     for (auto *Start : Starts) {
1466       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1467         RevertWhile(Start);
1468       else
1469         Start->eraseFromParent();
1470     }
1471     for (auto *Dec : Decs)
1472       RevertLoopDec(Dec);
1473 
1474     for (auto *End : Ends)
1475       RevertLoopEnd(End);
1476   }
1477   return Changed;
1478 }
1479 
1480 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1481   return new ARMLowOverheadLoops();
1482 }
1483