xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision b30adfb5295eaa894669e62eec627adda8820a43)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
14 ///   preheader's only predecessor.
15 /// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 /// In addition to this, we also look for the presence of the VCTP instruction,
19 /// which determines whether we can generate the tail-predicated low-overhead
20 /// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 /// A note on VPR.P0 (the lane mask):
39 /// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
40 /// "VPT Active" context (which includes low-overhead loops and vpt blocks).
41 /// They will simply "and" the result of their calculation with the current
42 /// value of VPR.P0. You can think of it like this:
43 /// \verbatim
44 /// if VPT active:    ; Between a DLSTP/LETP, or for predicated instrs
45 ///   VPR.P0 &= Value
46 /// else
47 ///   VPR.P0 = Value
48 /// \endverbatim
49 /// When we're inside the low-overhead loop (between DLSTP and LETP), we always
50 /// fall in the "VPT active" case, so we can consider that all VPR writes by
51 /// one of those instructions is actually an "and".
52 //===----------------------------------------------------------------------===//
53 
54 #include "ARM.h"
55 #include "ARMBaseInstrInfo.h"
56 #include "ARMBaseRegisterInfo.h"
57 #include "ARMBasicBlockInfo.h"
58 #include "ARMSubtarget.h"
59 #include "Thumb2InstrInfo.h"
60 #include "llvm/ADT/SetOperations.h"
61 #include "llvm/ADT/SmallSet.h"
62 #include "llvm/CodeGen/LivePhysRegs.h"
63 #include "llvm/CodeGen/MachineFunctionPass.h"
64 #include "llvm/CodeGen/MachineLoopInfo.h"
65 #include "llvm/CodeGen/MachineLoopUtils.h"
66 #include "llvm/CodeGen/MachineRegisterInfo.h"
67 #include "llvm/CodeGen/Passes.h"
68 #include "llvm/CodeGen/ReachingDefAnalysis.h"
69 #include "llvm/MC/MCInstrDesc.h"
70 
71 using namespace llvm;
72 
73 #define DEBUG_TYPE "arm-low-overhead-loops"
74 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
75 
76 namespace {
77 
78   using InstSet = SmallPtrSetImpl<MachineInstr *>;
79 
80   class PostOrderLoopTraversal {
81     MachineLoop &ML;
82     MachineLoopInfo &MLI;
83     SmallPtrSet<MachineBasicBlock*, 4> Visited;
84     SmallVector<MachineBasicBlock*, 4> Order;
85 
86   public:
87     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
88       : ML(ML), MLI(MLI) { }
89 
90     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
91       return Order;
92     }
93 
94     // Visit all the blocks within the loop, as well as exit blocks and any
95     // blocks properly dominating the header.
96     void ProcessLoop() {
97       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
98         (MachineBasicBlock *MBB) -> void {
99         if (Visited.count(MBB))
100           return;
101 
102         Visited.insert(MBB);
103         for (auto *Succ : MBB->successors()) {
104           if (!ML.contains(Succ))
105             continue;
106           Search(Succ);
107         }
108         Order.push_back(MBB);
109       };
110 
111       // Insert exit blocks.
112       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
113       ML.getExitBlocks(ExitBlocks);
114       for (auto *MBB : ExitBlocks)
115         Order.push_back(MBB);
116 
117       // Then add the loop body.
118       Search(ML.getHeader());
119 
120       // Then try the preheader and its predecessors.
121       std::function<void(MachineBasicBlock*)> GetPredecessor =
122         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
123         Order.push_back(MBB);
124         if (MBB->pred_size() == 1)
125           GetPredecessor(*MBB->pred_begin());
126       };
127 
128       if (auto *Preheader = ML.getLoopPreheader())
129         GetPredecessor(Preheader);
130       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
131         GetPredecessor(Preheader);
132     }
133   };
134 
135   struct PredicatedMI {
136     MachineInstr *MI = nullptr;
137     SetVector<MachineInstr*> Predicates;
138 
139   public:
140     PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
141       assert(I && "Instruction must not be null!");
142       Predicates.insert(Preds.begin(), Preds.end());
143     }
144   };
145 
146   // Represent a VPT block, a list of instructions that begins with a VPT/VPST
147   // and has a maximum of four proceeding instructions. All instructions within
148   // the block are predicated upon the vpr and we allow instructions to define
149   // the vpr within in the block too.
150   class VPTBlock {
151     // The predicate then instruction, which is either a VPT, or a VPST
152     // instruction.
153     std::unique_ptr<PredicatedMI> PredicateThen;
154     PredicatedMI *Divergent = nullptr;
155     SmallVector<PredicatedMI, 4> Insts;
156 
157   public:
158     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
159       PredicateThen = std::make_unique<PredicatedMI>(MI, Preds);
160     }
161 
162     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
163       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
164       if (!Divergent && !set_difference(Preds, PredicateThen->Predicates).empty()) {
165         Divergent = &Insts.back();
166         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
167       }
168       Insts.emplace_back(MI, Preds);
169       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
170     }
171 
172     // Have we found an instruction within the block which defines the vpr? If
173     // so, not all the instructions in the block will have the same predicate.
174     bool HasNonUniformPredicate() const {
175       return Divergent != nullptr;
176     }
177 
178     // Is the given instruction part of the predicate set controlling the entry
179     // to the block.
180     bool IsPredicatedOn(MachineInstr *MI) const {
181       return PredicateThen->Predicates.count(MI);
182     }
183 
184     // Returns true if this is a VPT instruction.
185     bool isVPT() const { return !isVPST(); }
186 
187     // Returns true if this is a VPST instruction.
188     bool isVPST() const {
189       return PredicateThen->MI->getOpcode() == ARM::MVE_VPST;
190     }
191 
192     // Is the given instruction the only predicate which controls the entry to
193     // the block.
194     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
195       return IsPredicatedOn(MI) && PredicateThen->Predicates.size() == 1;
196     }
197 
198     unsigned size() const { return Insts.size(); }
199     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
200     MachineInstr *getPredicateThen() const { return PredicateThen->MI; }
201     PredicatedMI *getDivergent() const { return Divergent; }
202   };
203 
204   struct LowOverheadLoop {
205 
206     MachineLoop &ML;
207     MachineBasicBlock *Preheader = nullptr;
208     MachineLoopInfo &MLI;
209     ReachingDefAnalysis &RDA;
210     const TargetRegisterInfo &TRI;
211     const ARMBaseInstrInfo &TII;
212     MachineFunction *MF = nullptr;
213     MachineInstr *InsertPt = nullptr;
214     MachineInstr *Start = nullptr;
215     MachineInstr *Dec = nullptr;
216     MachineInstr *End = nullptr;
217     MachineInstr *VCTP = nullptr;
218     MachineOperand TPNumElements;
219     SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs;
220     VPTBlock *CurrentBlock = nullptr;
221     SetVector<MachineInstr*> CurrentPredicate;
222     SmallVector<VPTBlock, 4> VPTBlocks;
223     SmallPtrSet<MachineInstr*, 4> ToRemove;
224     SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
225     bool Revert = false;
226     bool CannotTailPredicate = false;
227 
228     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
229                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
230                     const ARMBaseInstrInfo &TII)
231         : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII),
232           TPNumElements(MachineOperand::CreateImm(0)) {
233       MF = ML.getHeader()->getParent();
234       if (auto *MBB = ML.getLoopPreheader())
235         Preheader = MBB;
236       else if (auto *MBB = MLI.findLoopPreheader(&ML, true))
237         Preheader = MBB;
238     }
239 
240     // If this is an MVE instruction, check that we know how to use tail
241     // predication with it. Record VPT blocks and return whether the
242     // instruction is valid for tail predication.
243     bool ValidateMVEInst(MachineInstr *MI);
244 
245     void AnalyseMVEInst(MachineInstr *MI) {
246       CannotTailPredicate = !ValidateMVEInst(MI);
247     }
248 
249     bool IsTailPredicationLegal() const {
250       // For now, let's keep things really simple and only support a single
251       // block for tail predication.
252       return !Revert && FoundAllComponents() && VCTP &&
253              !CannotTailPredicate && ML.getNumBlocks() == 1;
254     }
255 
256     // Check that the predication in the loop will be equivalent once we
257     // perform the conversion. Also ensure that we can provide the number
258     // of elements to the loop start instruction.
259     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
260 
261     // Check that any values available outside of the loop will be the same
262     // after tail predication conversion.
263     bool ValidateLiveOuts();
264 
265     // Is it safe to define LR with DLS/WLS?
266     // LR can be defined if it is the operand to start, because it's the same
267     // value, or if it's going to be equivalent to the operand to Start.
268     MachineInstr *isSafeToDefineLR();
269 
270     // Check the branch targets are within range and we satisfy our
271     // restrictions.
272     void CheckLegality(ARMBasicBlockUtils *BBUtils);
273 
274     bool FoundAllComponents() const {
275       return Start && Dec && End;
276     }
277 
278     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
279 
280     // Return the operand for the loop start instruction. This will be the loop
281     // iteration count, or the number of elements if we're tail predicating.
282     MachineOperand &getLoopStartOperand() {
283       return IsTailPredicationLegal() ? TPNumElements : Start->getOperand(0);
284     }
285 
286     unsigned getStartOpcode() const {
287       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
288       if (!IsTailPredicationLegal())
289         return IsDo ? ARM::t2DLS : ARM::t2WLS;
290 
291       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
292     }
293 
294     void dump() const {
295       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
296       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
297       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
298       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
299       if (!FoundAllComponents())
300         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
301       else if (!(Start && Dec && End))
302         dbgs() << "ARM Loops: Failed to find all loop components.\n";
303     }
304   };
305 
306   class ARMLowOverheadLoops : public MachineFunctionPass {
307     MachineFunction           *MF = nullptr;
308     MachineLoopInfo           *MLI = nullptr;
309     ReachingDefAnalysis       *RDA = nullptr;
310     const ARMBaseInstrInfo    *TII = nullptr;
311     MachineRegisterInfo       *MRI = nullptr;
312     const TargetRegisterInfo  *TRI = nullptr;
313     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
314 
315   public:
316     static char ID;
317 
318     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
319 
320     void getAnalysisUsage(AnalysisUsage &AU) const override {
321       AU.setPreservesCFG();
322       AU.addRequired<MachineLoopInfo>();
323       AU.addRequired<ReachingDefAnalysis>();
324       MachineFunctionPass::getAnalysisUsage(AU);
325     }
326 
327     bool runOnMachineFunction(MachineFunction &MF) override;
328 
329     MachineFunctionProperties getRequiredProperties() const override {
330       return MachineFunctionProperties().set(
331           MachineFunctionProperties::Property::NoVRegs).set(
332           MachineFunctionProperties::Property::TracksLiveness);
333     }
334 
335     StringRef getPassName() const override {
336       return ARM_LOW_OVERHEAD_LOOPS_NAME;
337     }
338 
339   private:
340     bool ProcessLoop(MachineLoop *ML);
341 
342     bool RevertNonLoops();
343 
344     void RevertWhile(MachineInstr *MI) const;
345 
346     bool RevertLoopDec(MachineInstr *MI) const;
347 
348     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
349 
350     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
351 
352     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
353 
354     void Expand(LowOverheadLoop &LoLoop);
355 
356     void IterationCountDCE(LowOverheadLoop &LoLoop);
357   };
358 }
359 
360 char ARMLowOverheadLoops::ID = 0;
361 
362 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
363                 false, false)
364 
365 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
366   // We can define LR because LR already contains the same value.
367   if (Start->getOperand(0).getReg() == ARM::LR)
368     return Start;
369 
370   unsigned CountReg = Start->getOperand(0).getReg();
371   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
372     return MI->getOpcode() == ARM::tMOVr &&
373            MI->getOperand(0).getReg() == ARM::LR &&
374            MI->getOperand(1).getReg() == CountReg &&
375            MI->getOperand(2).getImm() == ARMCC::AL;
376    };
377 
378   MachineBasicBlock *MBB = Start->getParent();
379 
380   // Find an insertion point:
381   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
382   //   to Count before Start, we can insert at that mov.
383   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
384     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
385       return LRDef;
386 
387   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
388   //   to Count after Start, we can insert at that mov.
389   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
390     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
391       return LRDef;
392 
393   // We've found no suitable LR def and Start doesn't use LR directly. Can we
394   // just define LR anyway?
395   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
396 }
397 
398 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
399   assert(VCTP && "VCTP instruction expected but is not set");
400   // All predication within the loop should be based on vctp. If the block
401   // isn't predicated on entry, check whether the vctp is within the block
402   // and that all other instructions are then predicated on it.
403   for (auto &Block : VPTBlocks) {
404     if (Block.IsPredicatedOn(VCTP))
405       continue;
406     if (Block.HasNonUniformPredicate() && !isVCTP(Block.getDivergent()->MI)) {
407       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
408                         << *Block.getDivergent()->MI);
409       return false;
410     }
411     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
412     for (auto &PredMI : Insts) {
413       // Check the instructions in the block and only allow:
414       //   - VCTPs
415       //   - Instructions predicated on the main VCTP
416       //   - Any VCMP
417       //      - VCMPs just "and" their result with VPR.P0. Whether they are
418       //      located before/after the VCTP is irrelevant - the end result will
419       //      be the same in both cases, so there's no point in requiring them
420       //      to be located after the VCTP!
421       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI) ||
422           VCMPOpcodeToVPT(PredMI.MI->getOpcode()) != 0)
423         continue;
424       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
425                  << " - which is predicated on:\n";
426                  for (auto *MI : PredMI.Predicates)
427                    dbgs() << "   - " << *MI);
428       return false;
429     }
430   }
431 
432   if (!ValidateLiveOuts()) {
433     LLVM_DEBUG(dbgs() << "ARM Loops: Invalid live outs.\n");
434     return false;
435   }
436 
437   // For tail predication, we need to provide the number of elements, instead
438   // of the iteration count, to the loop start instruction. The number of
439   // elements is provided to the vctp instruction, so we need to check that
440   // we can use this register at InsertPt.
441   TPNumElements = VCTP->getOperand(1);
442   Register NumElements = TPNumElements.getReg();
443 
444   // If the register is defined within loop, then we can't perform TP.
445   // TODO: Check whether this is just a mov of a register that would be
446   // available.
447   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
448     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
449     return false;
450   }
451 
452   // The element count register maybe defined after InsertPt, in which case we
453   // need to try to move either InsertPt or the def so that the [w|d]lstp can
454   // use the value.
455   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
456 
457   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
458     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
459       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
460         ElemDef->removeFromParent();
461         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
462         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
463                    << *ElemDef);
464       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
465         StartInsertPt->removeFromParent();
466         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
467                               StartInsertPt);
468         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
469       } else {
470         // If we fail to move an instruction and the element count is provided
471         // by a mov, use the mov operand if it will have the same value at the
472         // insertion point
473         MachineOperand Operand = ElemDef->getOperand(1);
474         if (isMovRegOpcode(ElemDef->getOpcode()) &&
475             RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg()) ==
476                 RDA.getUniqueReachingMIDef(StartInsertPt, Operand.getReg())) {
477           TPNumElements = Operand;
478           NumElements = TPNumElements.getReg();
479         } else {
480           LLVM_DEBUG(dbgs()
481                      << "ARM Loops: Unable to move element count to loop "
482                      << "start instruction.\n");
483           return false;
484         }
485       }
486     }
487   }
488 
489   // Especially in the case of while loops, InsertBB may not be the
490   // preheader, so we need to check that the register isn't redefined
491   // before entering the loop.
492   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
493                                       Register NumElements) {
494     // NumElements is redefined in this block.
495     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
496       return true;
497 
498     // Don't continue searching up through multiple predecessors.
499     if (MBB->pred_size() > 1)
500       return true;
501 
502     return false;
503   };
504 
505   // First, find the block that looks like the preheader.
506   MachineBasicBlock *MBB = Preheader;
507   if (!MBB) {
508     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
509     return false;
510   }
511 
512   // Then search backwards for a def, until we get to InsertBB.
513   while (MBB != InsertBB) {
514     if (CannotProvideElements(MBB, NumElements)) {
515       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
516       return false;
517     }
518     MBB = *MBB->pred_begin();
519   }
520 
521   // Check that the value change of the element count is what we expect and
522   // that the predication will be equivalent. For this we need:
523   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
524   // and we can also allow register copies within the chain too.
525   auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
526     return -getAddSubImmediate(*MI) == ExpectedVecWidth;
527   };
528 
529   MBB = VCTP->getParent();
530   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
531     SmallPtrSet<MachineInstr*, 2> ElementChain;
532     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
533     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
534 
535     Ignore.insert(SecondaryVCTPs.begin(), SecondaryVCTPs.end());
536 
537     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
538       bool FoundSub = false;
539 
540       for (auto *MI : ElementChain) {
541         if (isMovRegOpcode(MI->getOpcode()))
542           continue;
543 
544         if (isSubImmOpcode(MI->getOpcode())) {
545           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
546             return false;
547           FoundSub = true;
548         } else
549           return false;
550       }
551 
552       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
553                  for (auto *MI : ElementChain)
554                    dbgs() << " - " << *MI);
555       ToRemove.insert(ElementChain.begin(), ElementChain.end());
556     }
557   }
558   return true;
559 }
560 
561 static bool isVectorPredicated(MachineInstr *MI) {
562   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
563   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
564 }
565 
566 static bool isRegInClass(const MachineOperand &MO,
567                          const TargetRegisterClass *Class) {
568   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
569 }
570 
571 // MVE 'narrowing' operate on half a lane, reading from half and writing
572 // to half, which are referred to has the top and bottom half. The other
573 // half retains its previous value.
574 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
575   const MCInstrDesc &MCID = MI.getDesc();
576   uint64_t Flags = MCID.TSFlags;
577   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
578 }
579 
580 // Some MVE instructions read from the top/bottom halves of their operand(s)
581 // and generate a vector result with result elements that are double the
582 // width of the input.
583 static bool producesDoubleWidthResult(const MachineInstr &MI) {
584   const MCInstrDesc &MCID = MI.getDesc();
585   uint64_t Flags = MCID.TSFlags;
586   return (Flags & ARMII::DoubleWidthResult) != 0;
587 }
588 
589 static bool isHorizontalReduction(const MachineInstr &MI) {
590   const MCInstrDesc &MCID = MI.getDesc();
591   uint64_t Flags = MCID.TSFlags;
592   return (Flags & ARMII::HorizontalReduction) != 0;
593 }
594 
595 // Can this instruction generate a non-zero result when given only zeroed
596 // operands? This allows us to know that, given operands with false bytes
597 // zeroed by masked loads, that the result will also contain zeros in those
598 // bytes.
599 static bool canGenerateNonZeros(const MachineInstr &MI) {
600 
601   // Check for instructions which can write into a larger element size,
602   // possibly writing into a previous zero'd lane.
603   if (producesDoubleWidthResult(MI))
604     return true;
605 
606   switch (MI.getOpcode()) {
607   default:
608     break;
609   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
610   // fp16 -> fp32 vector conversions.
611   // Instructions that perform a NOT will generate 1s from 0s.
612   case ARM::MVE_VMVN:
613   case ARM::MVE_VORN:
614   // Count leading zeros will do just that!
615   case ARM::MVE_VCLZs8:
616   case ARM::MVE_VCLZs16:
617   case ARM::MVE_VCLZs32:
618     return true;
619   }
620   return false;
621 }
622 
623 // Look at its register uses to see if it only can only receive zeros
624 // into its false lanes which would then produce zeros. Also check that
625 // the output register is also defined by an FalseLanesZero instruction
626 // so that if tail-predication happens, the lanes that aren't updated will
627 // still be zeros.
628 static bool producesFalseLanesZero(MachineInstr &MI,
629                                    const TargetRegisterClass *QPRs,
630                                    const ReachingDefAnalysis &RDA,
631                                    InstSet &FalseLanesZero) {
632   if (canGenerateNonZeros(MI))
633     return false;
634 
635   bool isPredicated = isVectorPredicated(&MI);
636   // Predicated loads will write zeros to the falsely predicated bytes of the
637   // destination register.
638   if (MI.mayLoad())
639     return isPredicated;
640 
641   auto IsZeroInit = [](MachineInstr *Def) {
642     return !isVectorPredicated(Def) &&
643            Def->getOpcode() == ARM::MVE_VMOVimmi32 &&
644            Def->getOperand(1).getImm() == 0;
645   };
646 
647   bool AllowScalars = isHorizontalReduction(MI);
648   for (auto &MO : MI.operands()) {
649     if (!MO.isReg() || !MO.getReg())
650       continue;
651     if (!isRegInClass(MO, QPRs) && AllowScalars)
652       continue;
653 
654     // Check that this instruction will produce zeros in its false lanes:
655     // - If it only consumes false lanes zero or constant 0 (vmov #0)
656     // - If it's predicated, it only matters that it's def register already has
657     //   false lane zeros, so we can ignore the uses.
658     SmallPtrSet<MachineInstr *, 2> Defs;
659     RDA.getGlobalReachingDefs(&MI, MO.getReg(), Defs);
660     for (auto *Def : Defs) {
661       if (Def == &MI || FalseLanesZero.count(Def) || IsZeroInit(Def))
662         continue;
663       if (MO.isUse() && isPredicated)
664         continue;
665       return false;
666     }
667   }
668   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
669   return true;
670 }
671 
672 bool LowOverheadLoop::ValidateLiveOuts() {
673   // We want to find out if the tail-predicated version of this loop will
674   // produce the same values as the loop in its original form. For this to
675   // be true, the newly inserted implicit predication must not change the
676   // the (observable) results.
677   // We're doing this because many instructions in the loop will not be
678   // predicated and so the conversion from VPT predication to tail-predication
679   // can result in different values being produced; due to the tail-predication
680   // preventing many instructions from updating their falsely predicated
681   // lanes. This analysis assumes that all the instructions perform lane-wise
682   // operations and don't perform any exchanges.
683   // A masked load, whether through VPT or tail predication, will write zeros
684   // to any of the falsely predicated bytes. So, from the loads, we know that
685   // the false lanes are zeroed and here we're trying to track that those false
686   // lanes remain zero, or where they change, the differences are masked away
687   // by their user(s).
688   // All MVE stores have to be predicated, so we know that any predicate load
689   // operands, or stored results are equivalent already. Other explicitly
690   // predicated instructions will perform the same operation in the original
691   // loop and the tail-predicated form too. Because of this, we can insert
692   // loads, stores and other predicated instructions into our Predicated
693   // set and build from there.
694   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
695   SetVector<MachineInstr *> FalseLanesUnknown;
696   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
697   SmallPtrSet<MachineInstr *, 4> Predicated;
698   MachineBasicBlock *Header = ML.getHeader();
699 
700   for (auto &MI : *Header) {
701     const MCInstrDesc &MCID = MI.getDesc();
702     uint64_t Flags = MCID.TSFlags;
703     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
704       continue;
705 
706     if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
707       continue;
708 
709     bool isPredicated = isVectorPredicated(&MI);
710     bool retainsOrReduces =
711       retainsPreviousHalfElement(MI) || isHorizontalReduction(MI);
712 
713     if (isPredicated)
714       Predicated.insert(&MI);
715     if (producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero))
716       FalseLanesZero.insert(&MI);
717     else if (MI.getNumDefs() == 0)
718       continue;
719     else if (!isPredicated && retainsOrReduces)
720       return false;
721     else
722       FalseLanesUnknown.insert(&MI);
723   }
724 
725   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
726                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
727     SmallPtrSet<MachineInstr *, 2> Uses;
728     RDA.getGlobalUses(MI, MO.getReg(), Uses);
729     for (auto *Use : Uses) {
730       if (Use != MI && !Predicated.count(Use))
731         return false;
732     }
733     return true;
734   };
735 
736   // Visit the unknowns in reverse so that we can start at the values being
737   // stored and then we can work towards the leaves, hopefully adding more
738   // instructions to Predicated. Successfully terminating the loop means that
739   // all the unknown values have to found to be masked by predicated user(s).
740   // For any unpredicated values, we store them in NonPredicated so that we
741   // can later check whether these form a reduction.
742   SmallPtrSet<MachineInstr*, 2> NonPredicated;
743   for (auto *MI : reverse(FalseLanesUnknown)) {
744     for (auto &MO : MI->operands()) {
745       if (!isRegInClass(MO, QPRs) || !MO.isDef())
746         continue;
747       if (!HasPredicatedUsers(MI, MO, Predicated)) {
748         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
749                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
750         NonPredicated.insert(MI);
751         break;
752       }
753     }
754     // Any unknown false lanes have been masked away by the user(s).
755     if (!NonPredicated.contains(MI))
756       Predicated.insert(MI);
757   }
758 
759   SmallPtrSet<MachineInstr *, 2> LiveOutMIs;
760   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
761   ML.getExitBlocks(ExitBlocks);
762   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
763   assert(ExitBlocks.size() == 1 && "Expected a single exit block");
764   MachineBasicBlock *ExitBB = ExitBlocks.front();
765   for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) {
766     // TODO: Instead of blocking predication, we could move the vctp to the exit
767     // block and calculate it's operand there in or the preheader.
768     if (RegMask.PhysReg == ARM::VPR)
769       return false;
770     // Check Q-regs that are live in the exit blocks. We don't collect scalars
771     // because they won't be affected by lane predication.
772     if (QPRs->contains(RegMask.PhysReg))
773       if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg))
774         LiveOutMIs.insert(MI);
775   }
776 
777   // We've already validated that any VPT predication within the loop will be
778   // equivalent when we perform the predication transformation; so we know that
779   // any VPT predicated instruction is predicated upon VCTP. Any live-out
780   // instruction needs to be predicated, so check this here. The instructions
781   // in NonPredicated have been found to be a reduction that we can ensure its
782   // legality.
783   for (auto *MI : LiveOutMIs) {
784     if (NonPredicated.count(MI) && FalseLanesUnknown.contains(MI)) {
785       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to handle live out: " << *MI);
786       return false;
787     }
788   }
789 
790   return true;
791 }
792 
793 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
794   if (Revert)
795     return;
796 
797   if (!End->getOperand(1).isMBB())
798     report_fatal_error("Expected LoopEnd to target basic block");
799 
800   // TODO Maybe there's cases where the target doesn't have to be the header,
801   // but for now be safe and revert.
802   if (End->getOperand(1).getMBB() != ML.getHeader()) {
803     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
804     Revert = true;
805     return;
806   }
807 
808   // The WLS and LE instructions have 12-bits for the label offset. WLS
809   // requires a positive offset, while LE uses negative.
810   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
811       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
812     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
813     Revert = true;
814     return;
815   }
816 
817   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
818       (BBUtils->getOffsetOf(Start) >
819        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
820        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
821     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
822     Revert = true;
823     return;
824   }
825 
826   InsertPt = Revert ? nullptr : isSafeToDefineLR();
827   if (!InsertPt) {
828     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
829     Revert = true;
830     return;
831   } else
832     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
833 
834   if (!IsTailPredicationLegal()) {
835     LLVM_DEBUG(if (!VCTP)
836                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
837                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
838     return;
839   }
840 
841   assert(ML.getBlocks().size() == 1 &&
842          "Shouldn't be processing a loop with more than one block");
843   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
844   LLVM_DEBUG(if (CannotTailPredicate)
845              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
846 }
847 
848 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
849   if (CannotTailPredicate)
850     return false;
851 
852   if (isVCTP(MI)) {
853     // If we find another VCTP, check whether it uses the same value as the main VCTP.
854     // If it does, store it in the SecondaryVCTPs set, else refuse it.
855     if (VCTP) {
856       if (!VCTP->getOperand(1).isIdenticalTo(MI->getOperand(1)) ||
857           !RDA.hasSameReachingDef(VCTP, MI, MI->getOperand(1).getReg())) {
858         LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
859                              "definition from the main VCTP");
860         return false;
861       }
862       LLVM_DEBUG(dbgs() << "ARM Loops: Found secondary VCTP: " << *MI);
863       SecondaryVCTPs.insert(MI);
864     } else {
865       LLVM_DEBUG(dbgs() << "ARM Loops: Found 'main' VCTP: " << *MI);
866       VCTP = MI;
867     }
868   } else if (isVPTOpcode(MI->getOpcode())) {
869     if (MI->getOpcode() != ARM::MVE_VPST) {
870       assert(MI->findRegisterDefOperandIdx(ARM::VPR) != -1 &&
871              "VPT does not implicitly define VPR?!");
872       CurrentPredicate.insert(MI);
873     }
874 
875     VPTBlocks.emplace_back(MI, CurrentPredicate);
876     CurrentBlock = &VPTBlocks.back();
877     return true;
878   } else if (MI->getOpcode() == ARM::MVE_VPSEL ||
879              MI->getOpcode() == ARM::MVE_VPNOT) {
880     // TODO: Allow VPSEL and VPNOT, we currently cannot because:
881     // 1) It will use the VPR as a predicate operand, but doesn't have to be
882     //    instead a VPT block, which means we can assert while building up
883     //    the VPT block because we don't find another VPT or VPST to being a new
884     //    one.
885     // 2) VPSEL still requires a VPR operand even after tail predicating,
886     //    which means we can't remove it unless there is another
887     //    instruction, such as vcmp, that can provide the VPR def.
888     return false;
889   }
890 
891   bool IsUse = false;
892   bool IsDef = false;
893   const MCInstrDesc &MCID = MI->getDesc();
894   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
895     const MachineOperand &MO = MI->getOperand(i);
896     if (!MO.isReg() || MO.getReg() != ARM::VPR)
897       continue;
898 
899     if (MO.isDef()) {
900       CurrentPredicate.insert(MI);
901       IsDef = true;
902     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
903       CurrentBlock->addInst(MI, CurrentPredicate);
904       IsUse = true;
905     } else {
906       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
907       return false;
908     }
909   }
910 
911   // If we find a vpr def that is not already predicated on the vctp, we've
912   // got disjoint predicates that may not be equivalent when we do the
913   // conversion.
914   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
915     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
916     return false;
917   }
918 
919   uint64_t Flags = MCID.TSFlags;
920   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
921     return true;
922 
923   // If we find an instruction that has been marked as not valid for tail
924   // predication, only allow the instruction if it's contained within a valid
925   // VPT block.
926   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
927     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
928     return false;
929   }
930 
931   // If the instruction is already explicitly predicated, then the conversion
932   // will be fine, but ensure that all store operations are predicated.
933   return !IsUse && MI->mayStore() ? false : true;
934 }
935 
936 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
937   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
938   if (!ST.hasLOB())
939     return false;
940 
941   MF = &mf;
942   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
943 
944   MLI = &getAnalysis<MachineLoopInfo>();
945   RDA = &getAnalysis<ReachingDefAnalysis>();
946   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
947   MRI = &MF->getRegInfo();
948   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
949   TRI = ST.getRegisterInfo();
950   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
951   BBUtils->computeAllBlockSizes();
952   BBUtils->adjustBBOffsetsAfter(&MF->front());
953 
954   bool Changed = false;
955   for (auto ML : *MLI) {
956     if (!ML->getParentLoop())
957       Changed |= ProcessLoop(ML);
958   }
959   Changed |= RevertNonLoops();
960   return Changed;
961 }
962 
963 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
964 
965   bool Changed = false;
966 
967   // Process inner loops first.
968   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
969     Changed |= ProcessLoop(*I);
970 
971   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
972              if (auto *Preheader = ML->getLoopPreheader())
973                dbgs() << " - " << Preheader->getName() << "\n";
974              else if (auto *Preheader = MLI->findLoopPreheader(ML))
975                dbgs() << " - " << Preheader->getName() << "\n";
976              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
977                dbgs() << " - " << Preheader->getName() << "\n";
978              for (auto *MBB : ML->getBlocks())
979                dbgs() << " - " << MBB->getName() << "\n";
980             );
981 
982   // Search the given block for a loop start instruction. If one isn't found,
983   // and there's only one predecessor block, search that one too.
984   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
985     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
986     for (auto &MI : *MBB) {
987       if (isLoopStart(MI))
988         return &MI;
989     }
990     if (MBB->pred_size() == 1)
991       return SearchForStart(*MBB->pred_begin());
992     return nullptr;
993   };
994 
995   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII);
996   // Search the preheader for the start intrinsic.
997   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
998   // with potentially multiple set.loop.iterations, so we need to enable this.
999   if (LoLoop.Preheader)
1000     LoLoop.Start = SearchForStart(LoLoop.Preheader);
1001   else
1002     return false;
1003 
1004   // Find the low-overhead loop components and decide whether or not to fall
1005   // back to a normal loop. Also look for a vctp instructions and decide
1006   // whether we can convert that predicate using tail predication.
1007   for (auto *MBB : reverse(ML->getBlocks())) {
1008     for (auto &MI : *MBB) {
1009       if (MI.isDebugValue())
1010         continue;
1011       else if (MI.getOpcode() == ARM::t2LoopDec)
1012         LoLoop.Dec = &MI;
1013       else if (MI.getOpcode() == ARM::t2LoopEnd)
1014         LoLoop.End = &MI;
1015       else if (isLoopStart(MI))
1016         LoLoop.Start = &MI;
1017       else if (MI.getDesc().isCall()) {
1018         // TODO: Though the call will require LE to execute again, does this
1019         // mean we should revert? Always executing LE hopefully should be
1020         // faster than performing a sub,cmp,br or even subs,br.
1021         LoLoop.Revert = true;
1022         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
1023       } else {
1024         // Record VPR defs and build up their corresponding vpt blocks.
1025         // Check we know how to tail predicate any mve instructions.
1026         LoLoop.AnalyseMVEInst(&MI);
1027       }
1028     }
1029   }
1030 
1031   LLVM_DEBUG(LoLoop.dump());
1032   if (!LoLoop.FoundAllComponents()) {
1033     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
1034     return false;
1035   }
1036 
1037   // Check that the only instruction using LoopDec is LoopEnd.
1038   // TODO: Check for copy chains that really have no effect.
1039   SmallPtrSet<MachineInstr*, 2> Uses;
1040   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
1041   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
1042     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
1043     LoLoop.Revert = true;
1044   }
1045   LoLoop.CheckLegality(BBUtils.get());
1046   Expand(LoLoop);
1047   return true;
1048 }
1049 
1050 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
1051 // beq that branches to the exit branch.
1052 // TODO: We could also try to generate a cbz if the value in LR is also in
1053 // another low register.
1054 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
1055   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
1056   MachineBasicBlock *MBB = MI->getParent();
1057   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1058                                     TII->get(ARM::t2CMPri));
1059   MIB.add(MI->getOperand(0));
1060   MIB.addImm(0);
1061   MIB.addImm(ARMCC::AL);
1062   MIB.addReg(ARM::NoRegister);
1063 
1064   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1065   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1066     ARM::tBcc : ARM::t2Bcc;
1067 
1068   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1069   MIB.add(MI->getOperand(1));   // branch target
1070   MIB.addImm(ARMCC::EQ);        // condition code
1071   MIB.addReg(ARM::CPSR);
1072   MI->eraseFromParent();
1073 }
1074 
1075 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
1076   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
1077   MachineBasicBlock *MBB = MI->getParent();
1078   SmallPtrSet<MachineInstr*, 1> Ignore;
1079   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
1080     if (I->getOpcode() == ARM::t2LoopEnd) {
1081       Ignore.insert(&*I);
1082       break;
1083     }
1084   }
1085 
1086   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
1087   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
1088 
1089   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1090                                     TII->get(ARM::t2SUBri));
1091   MIB.addDef(ARM::LR);
1092   MIB.add(MI->getOperand(1));
1093   MIB.add(MI->getOperand(2));
1094   MIB.addImm(ARMCC::AL);
1095   MIB.addReg(0);
1096 
1097   if (SetFlags) {
1098     MIB.addReg(ARM::CPSR);
1099     MIB->getOperand(5).setIsDef(true);
1100   } else
1101     MIB.addReg(0);
1102 
1103   MI->eraseFromParent();
1104   return SetFlags;
1105 }
1106 
1107 // Generate a subs, or sub and cmp, and a branch instead of an LE.
1108 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1109   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1110 
1111   MachineBasicBlock *MBB = MI->getParent();
1112   // Create cmp
1113   if (!SkipCmp) {
1114     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1115                                       TII->get(ARM::t2CMPri));
1116     MIB.addReg(ARM::LR);
1117     MIB.addImm(0);
1118     MIB.addImm(ARMCC::AL);
1119     MIB.addReg(ARM::NoRegister);
1120   }
1121 
1122   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1123   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1124     ARM::tBcc : ARM::t2Bcc;
1125 
1126   // Create bne
1127   MachineInstrBuilder MIB =
1128     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1129   MIB.add(MI->getOperand(1));   // branch target
1130   MIB.addImm(ARMCC::NE);        // condition code
1131   MIB.addReg(ARM::CPSR);
1132   MI->eraseFromParent();
1133 }
1134 
// Perform dead code elimination on the loop iteration count setup expression.
// If we are tail-predicating, the number of elements to be processed is the
// operand of the VCTP instruction in the vector body, see getCount(), which is
// register $r3 in this example:
//
//   $lr = big-itercount-expression
//   ..
//   t2DoLoopStart renamable $lr
//   vector.body:
//     ..
//     $vpr = MVE_VCTP32 renamable $r3
//     renamable $lr = t2LoopDec killed renamable $lr, 1
//     t2LoopEnd renamable $lr, %vector.body
//     tB %end
//
// What we would like to achieve here is to replace the do-loop start pseudo
// instruction t2DoLoopStart with:
//
//    $lr = MVE_DLSTP_32 killed renamable $r3
//
// Thus, $r3 which defines the number of elements, is written to $lr,
// and then we want to delete the whole chain that used to define $lr: the
// instruction providing the LoopStart's count operand and the instructions
// that RDA->collectKilledOperands reports as feeding it.
//
// Nothing is erased here; instructions are only queued in LoLoop.ToRemove.
// Note the two stages: the set returned by isSafeToRemove is queued as soon
// as that check passes, while the collectKilledOperands set is queued only
// if no IT block would be left partially emptied.
void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
  // Only the tail-predicated form stops using the old count expression.
  if (!LoLoop.IsTailPredicationLegal())
    return;

  LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");

  // The instruction defining operand 0 (the count) of the LoopStart.
  MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
  if (!Def) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
    return;
  }

  // Collect and remove the users of iteration count.
  // The loop pseudos and the LR insertion point are replaced/removed by the
  // expansion anyway, so they are passed as users to ignore.
  SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
                                            LoLoop.End, LoLoop.InsertPt };
  SmallPtrSet<MachineInstr*, 2> Remove;
  if (RDA->isSafeToRemove(Def, Remove, Killed))
    LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
  else {
    LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
    return;
  }

  // Collect the dead code and the MBBs in which they reside.
  RDA->collectKilledOperands(Def, Killed);
  SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
  for (auto *MI : Killed)
    BasicBlocks.insert(MI->getParent());

  // Collect IT blocks in all affected basic blocks: map each t2IT to the
  // local instructions that read the ITSTATE it defines.
  std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
  for (auto *MBB : BasicBlocks) {
    for (auto &MI : *MBB) {
      if (MI.getOpcode() != ARM::t2IT)
        continue;
      RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
    }
  }

  // Track IT blocks that would lose only some of their instructions. An IT
  // whose users are all being killed ends up out of ModifiedITs; one with
  // surviving users stays in (independent of the order Killed is visited).
  // NOTE(review): the t2IT itself is not queued for removal here; presumably
  // collectKilledOperands already places it in Killed — confirm.
  SmallPtrSet<MachineInstr*, 2> ModifiedITs;
  for (auto *MI : Killed) {
    if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
      MachineInstr *IT = RDA->getMIOperand(MI, *MO);
      auto &CurrentBlock = ITBlocks[IT];
      CurrentBlock.erase(MI);
      if (CurrentBlock.empty())
        ModifiedITs.erase(IT);
      else
        ModifiedITs.insert(IT);
    }
  }

  // Delete the killed instructions only if we don't have any IT blocks that
  // need to be modified because we need to fixup the mask.
  // TODO: Handle cases where IT blocks are modified.
  if (ModifiedITs.empty()) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
               for (auto *MI : Killed)
                 dbgs() << " - " << *MI);
    LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
  } else
    LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
}
1224 
1225 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1226   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1227   // When using tail-predication, try to delete the dead code that was used to
1228   // calculate the number of loop iterations.
1229   IterationCountDCE(LoLoop);
1230 
1231   MachineInstr *InsertPt = LoLoop.InsertPt;
1232   MachineInstr *Start = LoLoop.Start;
1233   MachineBasicBlock *MBB = InsertPt->getParent();
1234   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1235   unsigned Opc = LoLoop.getStartOpcode();
1236   MachineOperand &Count = LoLoop.getLoopStartOperand();
1237 
1238   MachineInstrBuilder MIB =
1239     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1240 
1241   MIB.addDef(ARM::LR);
1242   MIB.add(Count);
1243   if (!IsDo)
1244     MIB.add(Start->getOperand(1));
1245 
1246   // If we're inserting at a mov lr, then remove it as it's redundant.
1247   if (InsertPt != Start)
1248     LoLoop.ToRemove.insert(InsertPt);
1249   LoLoop.ToRemove.insert(Start);
1250   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1251   return &*MIB;
1252 }
1253 
1254 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1255   auto RemovePredicate = [](MachineInstr *MI) {
1256     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1257     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1258       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1259              "Expected Then predicate!");
1260       MI->getOperand(PIdx).setImm(ARMVCC::None);
1261       MI->getOperand(PIdx+1).setReg(0);
1262     } else
1263       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1264   };
1265 
1266   // There are a few scenarios which we have to fix up:
1267   // 1. VPT Blocks with non-uniform predicates:
1268   //    - a. When the divergent instruction is a vctp
1269   //    - b. When the block uses a vpst, and is only predicated on the vctp
1270   //    - c. When the block uses a vpt and (optionally) contains one or more
1271   //         vctp.
1272   // 2. VPT Blocks with uniform predicates:
1273   //    - a. The block uses a vpst, and is only predicated on the vctp
1274   for (auto &Block : LoLoop.getVPTBlocks()) {
1275     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1276     if (Block.HasNonUniformPredicate()) {
1277       PredicatedMI *Divergent = Block.getDivergent();
1278       if (isVCTP(Divergent->MI)) {
1279         // The vctp will be removed, so the block mask of the vp(s)t will need
1280         // to be recomputed.
1281         LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1282       } else if (Block.isVPST() && Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1283         // The VPT block has a non-uniform predicate but it uses a vpst and its
1284         // entry is guarded only by a vctp, which means we:
1285         // - Need to remove the original vpst.
1286         // - Then need to unpredicate any following instructions, until
1287         //   we come across the divergent vpr def.
1288         // - Insert a new vpst to predicate the instruction(s) that following
1289         //   the divergent vpr def.
1290         // TODO: We could be producing more VPT blocks than necessary and could
1291         // fold the newly created one into a proceeding one.
1292         for (auto I = ++MachineBasicBlock::iterator(Block.getPredicateThen()),
1293              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1294           RemovePredicate(&*I);
1295 
1296         unsigned Size = 0;
1297         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1298         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1299         MachineInstr *InsertAt = nullptr;
1300         while (I != E) {
1301           InsertAt = &*I;
1302           ++Size;
1303           ++I;
1304         }
1305         // Create a VPST (with a null mask for now, we'll recompute it later).
1306         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1307                                           InsertAt->getDebugLoc(),
1308                                           TII->get(ARM::MVE_VPST));
1309         MIB.addImm(0);
1310         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1311         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1312         LoLoop.ToRemove.insert(Block.getPredicateThen());
1313         LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
1314       }
1315       // Else, if the block uses a vpt, iterate over the block, removing the
1316       // extra VCTPs it may contain.
1317       else if (Block.isVPT()) {
1318         bool RemovedVCTP = false;
1319         for (PredicatedMI &Elt : Block.getInsts()) {
1320           MachineInstr *MI = Elt.MI;
1321           if (isVCTP(MI)) {
1322             LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *MI);
1323             LoLoop.ToRemove.insert(MI);
1324             RemovedVCTP = true;
1325             continue;
1326           }
1327         }
1328         if (RemovedVCTP)
1329           LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1330       }
1331     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP) && Block.isVPST()) {
1332       // A vpt block starting with VPST, is only predicated upon vctp and has no
1333       // internal vpr defs:
1334       // - Remove vpst.
1335       // - Unpredicate the remaining instructions.
1336       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1337       LoLoop.ToRemove.insert(Block.getPredicateThen());
1338       for (auto &PredMI : Insts)
1339         RemovePredicate(PredMI.MI);
1340     }
1341   }
1342   LLVM_DEBUG(dbgs() << "ARM Loops: Removing remaining VCTPs...\n");
1343   // Remove the "main" VCTP
1344   LoLoop.ToRemove.insert(LoLoop.VCTP);
1345   LLVM_DEBUG(dbgs() << "    " << *LoLoop.VCTP);
1346   // Remove remaining secondary VCTPs
1347   for (MachineInstr *VCTP : LoLoop.SecondaryVCTPs) {
1348     // All VCTPs that aren't marked for removal yet should be unpredicated ones.
1349     // The predicated ones should have already been marked for removal when
1350     // visiting the VPT blocks.
1351     if (LoLoop.ToRemove.insert(VCTP).second) {
1352       assert(getVPTInstrPredicate(*VCTP) == ARMVCC::None &&
1353              "Removing Predicated VCTP without updating the block mask!");
1354       LLVM_DEBUG(dbgs() << "    " << *VCTP);
1355     }
1356   }
1357 }
1358 
// Apply the decision made by CheckLegality. With LoLoop.Revert set, the
// LoopStart/LoopDec/LoopEnd pseudos are lowered to an ordinary
// cmp/sub/branch loop. Otherwise they are expanded into the real loop start
// and LE/LETP instructions, converting VPT blocks when tail-predication is
// legal, and everything queued in LoLoop.ToRemove is erased. In both cases the
// live-ins and liveness flags of the blocks visited by PostOrderLoopTraversal
// are recomputed and RDA is reset.
void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {

  // Combine the LoopDec and LoopEnd instructions into LE(TP).
  // The new instruction defines LR and takes the LoopEnd's operands 0 and 1
  // (count and target block); it is MVE_LETP when tail-predicating, else
  // t2LEUpdate. LoopDec and LoopEnd are queued for removal.
  auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
    MachineInstr *End = LoLoop.End;
    MachineBasicBlock *MBB = End->getParent();
    unsigned Opc = LoLoop.IsTailPredicationLegal() ?
      ARM::MVE_LETP : ARM::t2LEUpdate;
    MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
                                      TII->get(Opc));
    MIB.addDef(ARM::LR);
    MIB.add(End->getOperand(0));
    MIB.add(End->getOperand(1));
    LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
    LoLoop.ToRemove.insert(LoLoop.Dec);
    LoLoop.ToRemove.insert(End);
    return &*MIB;
  };

  // TODO: We should be able to automatically remove these branches before we
  // get here - probably by teaching analyzeBranch about the pseudo
  // instructions.
  // If there is an unconditional branch, after I, that just branches to the
  // next block, remove it.
  auto RemoveDeadBranch = [](MachineInstr *I) {
    MachineBasicBlock *BB = I->getParent();
    MachineInstr *Terminator = &BB->instr_back();
    if (Terminator->isUnconditionalBranch() && I != Terminator) {
      MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
      if (BB->isLayoutSuccessor(Succ)) {
        LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
        Terminator->eraseFromParent();
      }
    }
  };

  if (LoLoop.Revert) {
    // A while-loop start becomes cmp + beq; a do-loop start is just erased.
    if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
      RevertWhile(LoLoop.Start);
    else
      LoLoop.Start->eraseFromParent();
    // If the reverted sub already set the flags, the LoopEnd needs no cmp.
    bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
    RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
  } else {
    LoLoop.Start = ExpandLoopStart(LoLoop);
    RemoveDeadBranch(LoLoop.Start);
    LoLoop.End = ExpandLoopEnd(LoLoop);
    RemoveDeadBranch(LoLoop.End);
    if (LoLoop.IsTailPredicationLegal())
      ConvertVPTBlocks(LoLoop);
    for (auto *I : LoLoop.ToRemove) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
      I->eraseFromParent();
    }
    // Masks are recomputed only after all queued instructions (e.g. VCTPs)
    // have been erased, presumably so they reflect the final block contents.
    for (auto *I : LoLoop.BlockMasksToRecompute) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
      recomputeVPTBlockMask(*I);
      LLVM_DEBUG(dbgs() << "           ... done: " << *I);
    }
  }

  PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
  DFS.ProcessLoop();
  const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
  for (auto *MBB : PostOrder) {
    recomputeLiveIns(*MBB);
    // FIXME: For some reason, the live-in print order is non-deterministic for
    // our tests and I can't figure out why... So just sort them.
    MBB->sortUniqueLiveIns();
  }

  for (auto *MBB : reverse(PostOrder))
    recomputeLivenessFlags(*MBB);

  // We've moved, removed and inserted new instructions, so update RDA.
  RDA->reset();
}
1436 
1437 bool ARMLowOverheadLoops::RevertNonLoops() {
1438   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1439   bool Changed = false;
1440 
1441   for (auto &MBB : *MF) {
1442     SmallVector<MachineInstr*, 4> Starts;
1443     SmallVector<MachineInstr*, 4> Decs;
1444     SmallVector<MachineInstr*, 4> Ends;
1445 
1446     for (auto &I : MBB) {
1447       if (isLoopStart(I))
1448         Starts.push_back(&I);
1449       else if (I.getOpcode() == ARM::t2LoopDec)
1450         Decs.push_back(&I);
1451       else if (I.getOpcode() == ARM::t2LoopEnd)
1452         Ends.push_back(&I);
1453     }
1454 
1455     if (Starts.empty() && Decs.empty() && Ends.empty())
1456       continue;
1457 
1458     Changed = true;
1459 
1460     for (auto *Start : Starts) {
1461       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1462         RevertWhile(Start);
1463       else
1464         Start->eraseFromParent();
1465     }
1466     for (auto *Dec : Decs)
1467       RevertLoopDec(Dec);
1468 
1469     for (auto *End : Ends)
1470       RevertLoopEnd(End);
1471   }
1472   return Changed;
1473 }
1474 
/// Create a new instance of the ARM low-overhead loops finalization pass
/// (ARMLowOverheadLoops). Ownership of the returned pass passes to the caller.
FunctionPass *llvm::createARMLowOverheadLoopsPass() {
  return new ARMLowOverheadLoops();
}
1478