1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
14 ///   preheader's only predecessor.
15 /// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
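///
/// For illustration only, a single-block do-loop form might look roughly like
/// this before this pass runs (register names are invented for the sketch,
/// mirroring the example in IterationCountDCE below):
/// \verbatim
///   preheader:
///     $lr = t2DoLoopStart renamable $r0
///   vector.body:
///     ...
///     renamable $lr = t2LoopDec killed renamable $lr, 1
///     t2LoopEnd renamable $lr, %vector.body
///     tB %exit
/// \endverbatim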
17 ///
18 /// In addition to this, we also look for the presence of the VCTP instruction,
19 /// which determines whether we can generate the tail-predicated low-overhead
20 /// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a limited (though large)
26 /// range and a fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
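///
/// As a rough sketch (see RevertWhile, RevertLoopDec and RevertLoopEnd below),
/// reverting produces something like:
/// \verbatim
///   cmp   lr, #0              ; instead of: wls lr, rN, %exit
///   beq   %exit
///   ...
///   subs  lr, lr, #1          ; instead of: le  lr, %header
///   bne   %header
/// \endverbatim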
37 ///
38 /// A note on VPR.P0 (the lane mask):
39 /// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
40 /// "VPT Active" context (which includes low-overhead loops and vpt blocks).
41 /// They will simply "and" the result of their calculation with the current
42 /// value of VPR.P0. You can think of it like this:
43 /// \verbatim
44 /// if VPT active:    ; Between a DLSTP/LETP, or for predicated instrs
45 ///   VPR.P0 &= Value
46 /// else
47 ///   VPR.P0 = Value
48 /// \endverbatim
49 /// When we're inside the low-overhead loop (between DLSTP and LETP), we always
50 /// fall in the "VPT active" case, so we can consider that all VPR writes by
51 /// one of those instructions are actually an "and".
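/// For example (a sketch), a VCMP executed between the DLSTP and LETP behaves
/// as:
/// \verbatim
///   VPR.P0 &= <VCMP result>    ; rather than VPR.P0 = <VCMP result>
/// \endverbatim
/// which is why the position of such a VCMP relative to the VCTP does not
/// change the final mask.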
52 //===----------------------------------------------------------------------===//
53 
54 #include "ARM.h"
55 #include "ARMBaseInstrInfo.h"
56 #include "ARMBaseRegisterInfo.h"
57 #include "ARMBasicBlockInfo.h"
58 #include "ARMSubtarget.h"
59 #include "Thumb2InstrInfo.h"
60 #include "llvm/ADT/SetOperations.h"
61 #include "llvm/ADT/SmallSet.h"
62 #include "llvm/CodeGen/LivePhysRegs.h"
63 #include "llvm/CodeGen/MachineFunctionPass.h"
64 #include "llvm/CodeGen/MachineLoopInfo.h"
65 #include "llvm/CodeGen/MachineLoopUtils.h"
66 #include "llvm/CodeGen/MachineRegisterInfo.h"
67 #include "llvm/CodeGen/Passes.h"
68 #include "llvm/CodeGen/ReachingDefAnalysis.h"
69 #include "llvm/MC/MCInstrDesc.h"
70 
71 using namespace llvm;
72 
73 #define DEBUG_TYPE "arm-low-overhead-loops"
74 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
75 
76 namespace {
77 
78   using InstSet = SmallPtrSetImpl<MachineInstr *>;
79 
80   class PostOrderLoopTraversal {
81     MachineLoop &ML;
82     MachineLoopInfo &MLI;
83     SmallPtrSet<MachineBasicBlock*, 4> Visited;
84     SmallVector<MachineBasicBlock*, 4> Order;
85 
86   public:
87     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
88       : ML(ML), MLI(MLI) { }
89 
90     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
91       return Order;
92     }
93 
94     // Visit all the blocks within the loop, as well as exit blocks and any
95     // blocks properly dominating the header.
96     void ProcessLoop() {
97       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
98         (MachineBasicBlock *MBB) -> void {
99         if (Visited.count(MBB))
100           return;
101 
102         Visited.insert(MBB);
103         for (auto *Succ : MBB->successors()) {
104           if (!ML.contains(Succ))
105             continue;
106           Search(Succ);
107         }
108         Order.push_back(MBB);
109       };
110 
111       // Insert exit blocks.
112       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
113       ML.getExitBlocks(ExitBlocks);
114       for (auto *MBB : ExitBlocks)
115         Order.push_back(MBB);
116 
117       // Then add the loop body.
118       Search(ML.getHeader());
119 
120       // Then try the preheader and its predecessors.
121       std::function<void(MachineBasicBlock*)> GetPredecessor =
122         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
123         Order.push_back(MBB);
124         if (MBB->pred_size() == 1)
125           GetPredecessor(*MBB->pred_begin());
126       };
127 
128       if (auto *Preheader = ML.getLoopPreheader())
129         GetPredecessor(Preheader);
130       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
131         GetPredecessor(Preheader);
132     }
133   };
134 
135   struct PredicatedMI {
136     MachineInstr *MI = nullptr;
137     SetVector<MachineInstr*> Predicates;
138 
139   public:
140     PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
141       assert(I && "Instruction must not be null!");
142       Predicates.insert(Preds.begin(), Preds.end());
143     }
144   };
145 
146   // Represent a VPT block, a list of instructions that begins with a VPT/VPST
147   // and has a maximum of four following instructions. All instructions within
148   // the block are predicated upon the vpr and we allow instructions to define
149   // the vpr within the block too.
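  //
  // For illustration (a sketch only, operand details elided, opcodes only
  // indicative), a VPT block in MIR might look roughly like:
  //
  //   MVE_VPST <mask>
  //   <predicated MVE instruction>   ; e.g. a predicated vector load
  //   <predicated MVE instruction>   ; e.g. a predicated vector add
  //
  // where each instruction in the block carries a vpred operand tying it to
  // VPR.P0.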
150   class VPTBlock {
151     // The predicate then instruction, which is either a VPT, or a VPST
152     // instruction.
153     std::unique_ptr<PredicatedMI> PredicateThen;
154     PredicatedMI *Divergent = nullptr;
155     SmallVector<PredicatedMI, 4> Insts;
156 
157   public:
158     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
159       PredicateThen = std::make_unique<PredicatedMI>(MI, Preds);
160     }
161 
162     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
163       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
164       if (!Divergent && !set_difference(Preds, PredicateThen->Predicates).empty()) {
165         Divergent = &Insts.back();
166         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
167       }
168       Insts.emplace_back(MI, Preds);
169       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
170     }
171 
172     // Have we found an instruction within the block which defines the vpr? If
173     // so, not all the instructions in the block will have the same predicate.
174     bool HasNonUniformPredicate() const {
175       return Divergent != nullptr;
176     }
177 
178     // Is the given instruction part of the predicate set controlling the entry
179     // to the block?
180     bool IsPredicatedOn(MachineInstr *MI) const {
181       return PredicateThen->Predicates.count(MI);
182     }
183 
184     // Returns true if this is a VPT instruction.
185     bool isVPT() const { return !isVPST(); }
186 
187     // Returns true if this is a VPST instruction.
188     bool isVPST() const {
189       return PredicateThen->MI->getOpcode() == ARM::MVE_VPST;
190     }
191 
192     // Is the given instruction the only predicate which controls the entry to
193     // the block?
194     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
195       return IsPredicatedOn(MI) && PredicateThen->Predicates.size() == 1;
196     }
197 
198     unsigned size() const { return Insts.size(); }
199     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
200     MachineInstr *getPredicateThen() const { return PredicateThen->MI; }
201     PredicatedMI *getDivergent() const { return Divergent; }
202   };
203 
204   struct LowOverheadLoop {
205 
206     MachineLoop &ML;
207     MachineBasicBlock *Preheader = nullptr;
208     MachineLoopInfo &MLI;
209     ReachingDefAnalysis &RDA;
210     const TargetRegisterInfo &TRI;
211     const ARMBaseInstrInfo &TII;
212     MachineFunction *MF = nullptr;
213     MachineInstr *InsertPt = nullptr;
214     MachineInstr *Start = nullptr;
215     MachineInstr *Dec = nullptr;
216     MachineInstr *End = nullptr;
217     MachineInstr *VCTP = nullptr;
218     MachineOperand TPNumElements;
219     SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs;
220     VPTBlock *CurrentBlock = nullptr;
221     SetVector<MachineInstr*> CurrentPredicate;
222     SmallVector<VPTBlock, 4> VPTBlocks;
223     SmallPtrSet<MachineInstr*, 4> ToRemove;
224     SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
225     bool Revert = false;
226     bool CannotTailPredicate = false;
227 
228     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
229                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
230                     const ARMBaseInstrInfo &TII)
231         : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII),
232           TPNumElements(MachineOperand::CreateImm(0)) {
233       MF = ML.getHeader()->getParent();
234       if (auto *MBB = ML.getLoopPreheader())
235         Preheader = MBB;
236       else if (auto *MBB = MLI.findLoopPreheader(&ML, true))
237         Preheader = MBB;
238     }
239 
240     // If this is an MVE instruction, check that we know how to use tail
241     // predication with it. Record VPT blocks and return whether the
242     // instruction is valid for tail predication.
243     bool ValidateMVEInst(MachineInstr *MI);
244 
245     void AnalyseMVEInst(MachineInstr *MI) {
246       CannotTailPredicate = !ValidateMVEInst(MI);
247     }
248 
249     bool IsTailPredicationLegal() const {
250       // For now, let's keep things really simple and only support a single
251       // block for tail predication.
252       return !Revert && FoundAllComponents() && VCTP &&
253              !CannotTailPredicate && ML.getNumBlocks() == 1;
254     }
255 
256     // Check that the predication in the loop will be equivalent once we
257     // perform the conversion. Also ensure that we can provide the number
258     // of elements to the loop start instruction.
259     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
260 
261     // Check that any values available outside of the loop will be the same
262     // after tail predication conversion.
263     bool ValidateLiveOuts();
264 
265     // Is it safe to define LR with DLS/WLS?
266     // LR can be defined if it is the operand to Start, because it's the same
267     // value, or if it's going to be equivalent to the operand to Start.
268     MachineInstr *isSafeToDefineLR();
269 
270     // Check the branch targets are within range and we satisfy our
271     // restrictions.
272     void CheckLegality(ARMBasicBlockUtils *BBUtils);
273 
274     bool FoundAllComponents() const {
275       return Start && Dec && End;
276     }
277 
278     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
279 
280     // Return the operand for the loop start instruction. This will be the loop
281     // iteration count, or the number of elements if we're tail predicating.
282     MachineOperand &getLoopStartOperand() {
283       return IsTailPredicationLegal() ? TPNumElements : Start->getOperand(0);
284     }
285 
286     unsigned getStartOpcode() const {
287       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
288       if (!IsTailPredicationLegal())
289         return IsDo ? ARM::t2DLS : ARM::t2WLS;
290 
291       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
292     }
293 
294     void dump() const {
295       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
296       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
297       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
298       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
299       if (!FoundAllComponents())
300         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
301       else if (!(Start && Dec && End))
302         dbgs() << "ARM Loops: Failed to find all loop components.\n";
303     }
304   };
305 
306   class ARMLowOverheadLoops : public MachineFunctionPass {
307     MachineFunction           *MF = nullptr;
308     MachineLoopInfo           *MLI = nullptr;
309     ReachingDefAnalysis       *RDA = nullptr;
310     const ARMBaseInstrInfo    *TII = nullptr;
311     MachineRegisterInfo       *MRI = nullptr;
312     const TargetRegisterInfo  *TRI = nullptr;
313     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
314 
315   public:
316     static char ID;
317 
318     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
319 
320     void getAnalysisUsage(AnalysisUsage &AU) const override {
321       AU.setPreservesCFG();
322       AU.addRequired<MachineLoopInfo>();
323       AU.addRequired<ReachingDefAnalysis>();
324       MachineFunctionPass::getAnalysisUsage(AU);
325     }
326 
327     bool runOnMachineFunction(MachineFunction &MF) override;
328 
329     MachineFunctionProperties getRequiredProperties() const override {
330       return MachineFunctionProperties().set(
331           MachineFunctionProperties::Property::NoVRegs).set(
332           MachineFunctionProperties::Property::TracksLiveness);
333     }
334 
335     StringRef getPassName() const override {
336       return ARM_LOW_OVERHEAD_LOOPS_NAME;
337     }
338 
339   private:
340     bool ProcessLoop(MachineLoop *ML);
341 
342     bool RevertNonLoops();
343 
344     void RevertWhile(MachineInstr *MI) const;
345 
346     bool RevertLoopDec(MachineInstr *MI) const;
347 
348     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
349 
350     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
351 
352     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
353 
354     void Expand(LowOverheadLoop &LoLoop);
355 
356     void IterationCountDCE(LowOverheadLoop &LoLoop);
357   };
358 }
359 
360 char ARMLowOverheadLoops::ID = 0;
361 
362 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
363                 false, false)
364 
365 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
366   // We can define LR because LR already contains the same value.
367   if (Start->getOperand(0).getReg() == ARM::LR)
368     return Start;
369 
370   unsigned CountReg = Start->getOperand(0).getReg();
371   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
372     return MI->getOpcode() == ARM::tMOVr &&
373            MI->getOperand(0).getReg() == ARM::LR &&
374            MI->getOperand(1).getReg() == CountReg &&
375            MI->getOperand(2).getImm() == ARMCC::AL;
376    };
377 
378   MachineBasicBlock *MBB = Start->getParent();
379 
380   // Find an insertion point:
381   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
382   //   to Count before Start, we can insert at that mov.
383   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
384     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
385       return LRDef;
386 
387   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
388   //   to Count after Start, we can insert at that mov.
389   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
390     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
391       return LRDef;
392 
393   // We've found no suitable LR def and Start doesn't use LR directly. Can we
394   // just define LR anyway?
395   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
396 }
397 
398 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
399   assert(VCTP && "VCTP instruction expected but is not set");
400   // All predication within the loop should be based on vctp. If the block
401   // isn't predicated on entry, check whether the vctp is within the block
402   // and that all other instructions are then predicated on it.
403   for (auto &Block : VPTBlocks) {
404     if (Block.IsPredicatedOn(VCTP))
405       continue;
406     if (Block.HasNonUniformPredicate() && !isVCTP(Block.getDivergent()->MI)) {
407       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
408                         << *Block.getDivergent()->MI);
409       return false;
410     }
411     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
412     for (auto &PredMI : Insts) {
413       // Check the instructions in the block and only allow:
414       //   - VCTPs
415       //   - Instructions predicated on the main VCTP
416       //   - Any VCMP
417       //      - VCMPs just "and" their result with VPR.P0. Whether they are
418       //      located before/after the VCTP is irrelevant - the end result will
419       //      be the same in both cases, so there's no point in requiring them
420       //      to be located after the VCTP!
421       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI) ||
422           VCMPOpcodeToVPT(PredMI.MI->getOpcode()) != 0)
423         continue;
424       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
425                  << " - which is predicated on:\n";
426                  for (auto *MI : PredMI.Predicates)
427                    dbgs() << "   - " << *MI);
428       return false;
429     }
430   }
431 
432   if (!ValidateLiveOuts()) {
433     LLVM_DEBUG(dbgs() << "ARM Loops: Invalid live outs.\n");
434     return false;
435   }
436 
437   // For tail predication, we need to provide the number of elements, instead
438   // of the iteration count, to the loop start instruction. The number of
439   // elements is provided to the vctp instruction, so we need to check that
440   // we can use this register at InsertPt.
441   TPNumElements = VCTP->getOperand(1);
442   Register NumElements = TPNumElements.getReg();
443 
444   // If the register is defined within the loop, then we can't perform TP.
445   // TODO: Check whether this is just a mov of a register that would be
446   // available.
447   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
448     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
449     return false;
450   }
451 
452   // The element count register may be defined after InsertPt, in which case we
453   // need to try to move either InsertPt or the def so that the [w|d]lstp can
454   // use the value.
455   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
456 
457   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
458     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
459       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
460         ElemDef->removeFromParent();
461         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
462         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
463                    << *ElemDef);
464       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
465         StartInsertPt->removeFromParent();
466         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
467                               StartInsertPt);
468         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
469       } else {
470         // If we fail to move an instruction and the element count is provided
471         // by a mov, use the mov operand if it will have the same value at the
472         // insertion point.
473         MachineOperand Operand = ElemDef->getOperand(1);
474         if (isMovRegOpcode(ElemDef->getOpcode()) &&
475             RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg()) ==
476                 RDA.getUniqueReachingMIDef(StartInsertPt, Operand.getReg())) {
477           TPNumElements = Operand;
478           NumElements = TPNumElements.getReg();
479         } else {
480           LLVM_DEBUG(dbgs()
481                      << "ARM Loops: Unable to move element count to loop "
482                      << "start instruction.\n");
483           return false;
484         }
485       }
486     }
487   }
488 
489   // Especially in the case of while loops, InsertBB may not be the
490   // preheader, so we need to check that the register isn't redefined
491   // before entering the loop.
492   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
493                                       Register NumElements) {
494     // NumElements is redefined in this block.
495     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
496       return true;
497 
498     // Don't continue searching up through multiple predecessors.
499     if (MBB->pred_size() > 1)
500       return true;
501 
502     return false;
503   };
504 
505   // First, find the block that looks like the preheader.
506   MachineBasicBlock *MBB = Preheader;
507   if (!MBB) {
508     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
509     return false;
510   }
511 
512   // Then search backwards for a def, until we get to InsertBB.
513   while (MBB != InsertBB) {
514     if (CannotProvideElements(MBB, NumElements)) {
515       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
516       return false;
517     }
518     MBB = *MBB->pred_begin();
519   }
520 
521   // Check that the value change of the element count is what we expect and
522   // that the predication will be equivalent. For this we need:
523   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
524   // and we can also allow register copies within the chain too.
525   auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
526     return -getAddSubImmediate(*MI) == ExpectedVecWidth;
527   };
528 
529   MBB = VCTP->getParent();
530   // Remove modifications to the element count since they have no purpose in a
531   // tail predicated loop. Explicitly refer to the vctp operand no matter which
532   // register NumElements has been assigned to, since that is what the
533   // modifications will be using.
534   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(),
535                                              VCTP->getOperand(1).getReg())) {
536     SmallPtrSet<MachineInstr*, 2> ElementChain;
537     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
538     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
539 
540     Ignore.insert(SecondaryVCTPs.begin(), SecondaryVCTPs.end());
541 
542     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
543       bool FoundSub = false;
544 
545       for (auto *MI : ElementChain) {
546         if (isMovRegOpcode(MI->getOpcode()))
547           continue;
548 
549         if (isSubImmOpcode(MI->getOpcode())) {
550           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
551             return false;
552           FoundSub = true;
553         } else
554           return false;
555       }
556 
557       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
558                  for (auto *MI : ElementChain)
559                    dbgs() << " - " << *MI);
560       ToRemove.insert(ElementChain.begin(), ElementChain.end());
561     }
562   }
563   return true;
564 }
565 
566 static bool isVectorPredicated(MachineInstr *MI) {
567   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
568   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
569 }
570 
571 static bool isRegInClass(const MachineOperand &MO,
572                          const TargetRegisterClass *Class) {
573   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
574 }
575 
576 // MVE 'narrowing' instructions operate on half a lane, reading from half and
577 // writing to half, which are referred to as the top and bottom half. The other
578 // half retains its previous value.
579 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
580   const MCInstrDesc &MCID = MI.getDesc();
581   uint64_t Flags = MCID.TSFlags;
582   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
583 }
584 
585 // Some MVE instructions read from the top/bottom halves of their operand(s)
586 // and generate a vector result with result elements that are double the
587 // width of the input.
588 static bool producesDoubleWidthResult(const MachineInstr &MI) {
589   const MCInstrDesc &MCID = MI.getDesc();
590   uint64_t Flags = MCID.TSFlags;
591   return (Flags & ARMII::DoubleWidthResult) != 0;
592 }
593 
594 static bool isHorizontalReduction(const MachineInstr &MI) {
595   const MCInstrDesc &MCID = MI.getDesc();
596   uint64_t Flags = MCID.TSFlags;
597   return (Flags & ARMII::HorizontalReduction) != 0;
598 }
599 
600 // Can this instruction generate a non-zero result when given only zeroed
601 // operands? This allows us to know that, given operands with false bytes
602 // zeroed by masked loads, the result will also contain zeros in those
603 // bytes.
604 static bool canGenerateNonZeros(const MachineInstr &MI) {
605 
606   // Check for instructions which can write into a larger element size,
607   // possibly writing into a previously zeroed lane.
608   if (producesDoubleWidthResult(MI))
609     return true;
610 
611   switch (MI.getOpcode()) {
612   default:
613     break;
614   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
615   // fp16 -> fp32 vector conversions.
616   // Instructions that perform a NOT will generate 1s from 0s.
617   case ARM::MVE_VMVN:
618   case ARM::MVE_VORN:
619   // Count leading zeros will do just that!
620   case ARM::MVE_VCLZs8:
621   case ARM::MVE_VCLZs16:
622   case ARM::MVE_VCLZs32:
623     return true;
624   }
625   return false;
626 }
627 
628 // Look at its register uses to see if it can only receive zeros
629 // into its false lanes which would then produce zeros. Also check that
630 // the output register is defined by a FalseLanesZero instruction
631 // so that if tail-predication happens, the lanes that aren't updated will
632 // still be zeros.
633 static bool producesFalseLanesZero(MachineInstr &MI,
634                                    const TargetRegisterClass *QPRs,
635                                    const ReachingDefAnalysis &RDA,
636                                    InstSet &FalseLanesZero) {
637   if (canGenerateNonZeros(MI))
638     return false;
639 
640   bool isPredicated = isVectorPredicated(&MI);
641   // Predicated loads will write zeros to the falsely predicated bytes of the
642   // destination register.
643   if (MI.mayLoad())
644     return isPredicated;
645 
646   auto IsZeroInit = [](MachineInstr *Def) {
647     return !isVectorPredicated(Def) &&
648            Def->getOpcode() == ARM::MVE_VMOVimmi32 &&
649            Def->getOperand(1).getImm() == 0;
650   };
651 
652   bool AllowScalars = isHorizontalReduction(MI);
653   for (auto &MO : MI.operands()) {
654     if (!MO.isReg() || !MO.getReg())
655       continue;
656     if (!isRegInClass(MO, QPRs) && AllowScalars)
657       continue;
658 
659     // Check that this instruction will produce zeros in its false lanes:
660     // - If it only consumes false-lanes-zero values or a constant 0 (vmov #0).
661     // - If it's predicated, it only matters that its def register already has
662     //   false lane zeros, so we can ignore the uses.
663     SmallPtrSet<MachineInstr *, 2> Defs;
664     RDA.getGlobalReachingDefs(&MI, MO.getReg(), Defs);
665     for (auto *Def : Defs) {
666       if (Def == &MI || FalseLanesZero.count(Def) || IsZeroInit(Def))
667         continue;
668       if (MO.isUse() && isPredicated)
669         continue;
670       return false;
671     }
672   }
673   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
674   return true;
675 }
676 
677 bool LowOverheadLoop::ValidateLiveOuts() {
678   // We want to find out if the tail-predicated version of this loop will
679   // produce the same values as the loop in its original form. For this to
680   // be true, the newly inserted implicit predication must not change the
681   // (observable) results.
682   // We're doing this because many instructions in the loop will not be
683   // predicated and so the conversion from VPT predication to tail-predication
684   // can result in different values being produced, because the tail-predication
685   // prevents many instructions from updating their falsely predicated
686   // lanes. This analysis assumes that all the instructions perform lane-wise
687   // operations and don't perform any exchanges.
688   // A masked load, whether through VPT or tail predication, will write zeros
689   // to any of the falsely predicated bytes. So, from the loads, we know that
690   // the false lanes are zeroed and here we're trying to track that those false
691   // lanes remain zero, or where they change, the differences are masked away
692   // by their user(s).
693   // All MVE stores have to be predicated, so we know that any predicated load
694   // operands, or stored results, are equivalent already. Other explicitly
695   // predicated instructions will perform the same operation in the original
696   // loop and the tail-predicated form too. Because of this, we can insert
697   // loads, stores and other predicated instructions into our Predicated
698   // set and build from there.
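  //
  // A rough worked example of the reasoning (a sketch, not tied to specific
  // opcodes): a predicated load writes zeros to its false lanes; an
  // unpredicated add of two such values still has zeroed false lanes; a
  // predicated store of that result then writes the same active lanes as the
  // original, VPT-predicated loop would.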
699   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
700   SetVector<MachineInstr *> FalseLanesUnknown;
701   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
702   SmallPtrSet<MachineInstr *, 4> Predicated;
703   MachineBasicBlock *Header = ML.getHeader();
704 
705   for (auto &MI : *Header) {
706     const MCInstrDesc &MCID = MI.getDesc();
707     uint64_t Flags = MCID.TSFlags;
708     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
709       continue;
710 
711     if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
712       continue;
713 
714     bool isPredicated = isVectorPredicated(&MI);
715     bool retainsOrReduces =
716       retainsPreviousHalfElement(MI) || isHorizontalReduction(MI);
717 
718     if (isPredicated)
719       Predicated.insert(&MI);
720     if (producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero))
721       FalseLanesZero.insert(&MI);
722     else if (MI.getNumDefs() == 0)
723       continue;
724     else if (!isPredicated && retainsOrReduces)
725       return false;
726     else if (!isPredicated)
727       FalseLanesUnknown.insert(&MI);
728   }
729 
730   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
731                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
732     SmallPtrSet<MachineInstr *, 2> Uses;
733     RDA.getGlobalUses(MI, MO.getReg(), Uses);
734     for (auto *Use : Uses) {
735       if (Use != MI && !Predicated.count(Use))
736         return false;
737     }
738     return true;
739   };
740 
741   // Visit the unknowns in reverse so that we can start at the values being
742   // stored and then we can work towards the leaves, hopefully adding more
743   // instructions to Predicated. Successfully terminating the loop means that
744   // all the unknown values have been found to be masked by predicated user(s).
745   // For any unpredicated values, we store them in NonPredicated so that we
746   // can later check whether these form a reduction.
747   SmallPtrSet<MachineInstr*, 2> NonPredicated;
748   for (auto *MI : reverse(FalseLanesUnknown)) {
749     for (auto &MO : MI->operands()) {
750       if (!isRegInClass(MO, QPRs) || !MO.isDef())
751         continue;
752       if (!HasPredicatedUsers(MI, MO, Predicated)) {
753         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
754                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
755         NonPredicated.insert(MI);
756         break;
757       }
758     }
759     // Any unknown false lanes have been masked away by the user(s).
760     if (!NonPredicated.contains(MI))
761       Predicated.insert(MI);
762   }
763 
764   SmallPtrSet<MachineInstr *, 2> LiveOutMIs;
765   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
766   ML.getExitBlocks(ExitBlocks);
767   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
768   assert(ExitBlocks.size() == 1 && "Expected a single exit block");
769   MachineBasicBlock *ExitBB = ExitBlocks.front();
770   for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) {
771     // TODO: Instead of blocking predication, we could move the vctp to the exit
772     // block and calculate its operand there or in the preheader.
773     if (RegMask.PhysReg == ARM::VPR)
774       return false;
775     // Check Q-regs that are live in the exit blocks. We don't collect scalars
776     // because they won't be affected by lane predication.
777     if (QPRs->contains(RegMask.PhysReg))
778       if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg))
779         LiveOutMIs.insert(MI);
780   }
781 
782   // We've already validated that any VPT predication within the loop will be
783   // equivalent when we perform the predication transformation, so we know that
784   // any VPT predicated instruction is predicated upon VCTP. Any live-out
785   // instruction needs to be predicated, so check this here. The instructions
786   // in NonPredicated have been found to be part of a reduction for which we
787   // can ensure legality.
788   for (auto *MI : LiveOutMIs) {
789     if (NonPredicated.count(MI) && FalseLanesUnknown.contains(MI)) {
790       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to handle live out: " << *MI);
791       return false;
792     }
793   }
794 
795   return true;
796 }
797 
798 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
799   if (Revert)
800     return;
801 
802   if (!End->getOperand(1).isMBB())
803     report_fatal_error("Expected LoopEnd to target basic block");
804 
805   // TODO: Maybe there are cases where the target doesn't have to be the header,
806   // but for now be safe and revert.
807   if (End->getOperand(1).getMBB() != ML.getHeader()) {
808     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targeting header.\n");
809     Revert = true;
810     return;
811   }
812 
813   // The WLS and LE instructions have 12 bits for the label offset. WLS
814   // requires a positive offset, while LE uses a negative one.
815   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
816       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
817     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
818     Revert = true;
819     return;
820   }
821 
822   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
823       (BBUtils->getOffsetOf(Start) >
824        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
825        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
826     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
827     Revert = true;
828     return;
829   }
830 
831   InsertPt = Revert ? nullptr : isSafeToDefineLR();
832   if (!InsertPt) {
833     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
834     Revert = true;
835     return;
836   } else
837     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
838 
839   if (!IsTailPredicationLegal()) {
840     LLVM_DEBUG(if (!VCTP)
841                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
842                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
843     return;
844   }
845 
846   assert(ML.getBlocks().size() == 1 &&
847          "Shouldn't be processing a loop with more than one block");
848   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
849   LLVM_DEBUG(if (CannotTailPredicate)
850              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
851 }
852 
853 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
854   if (CannotTailPredicate)
855     return false;
856 
857   const MCInstrDesc &MCID = MI->getDesc();
858   uint64_t Flags = MCID.TSFlags;
859   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
860     return true;
861 
862   if (MI->getOpcode() == ARM::MVE_VPSEL ||
863       MI->getOpcode() == ARM::MVE_VPNOT) {
864     // TODO: Allow VPSEL and VPNOT, we currently cannot because:
865     // 1) It will use the VPR as a predicate operand, but doesn't have to be
866     //    inside a VPT block, which means we can assert while building up
867     //    the VPT block because we don't find another VPT or VPST to begin a new
868     //    one.
869     // 2) VPSEL still requires a VPR operand even after tail predicating,
870     //    which means we can't remove it unless there is another
871     //    instruction, such as vcmp, that can provide the VPR def.
872     return false;
873   }
874 
875   if (isVCTP(MI)) {
876     // If we find another VCTP, check whether it uses the same value as the main VCTP.
877     // If it does, store it in the SecondaryVCTPs set, else refuse it.
878     if (VCTP) {
879       if (!VCTP->getOperand(1).isIdenticalTo(MI->getOperand(1)) ||
880           !RDA.hasSameReachingDef(VCTP, MI, MI->getOperand(1).getReg())) {
881         LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
882                              "definition from the main VCTP");
883         return false;
884       }
885       LLVM_DEBUG(dbgs() << "ARM Loops: Found secondary VCTP: " << *MI);
886       SecondaryVCTPs.insert(MI);
887     } else {
888       LLVM_DEBUG(dbgs() << "ARM Loops: Found 'main' VCTP: " << *MI);
889       VCTP = MI;
890     }
891   } else if (isVPTOpcode(MI->getOpcode())) {
892     if (MI->getOpcode() != ARM::MVE_VPST) {
893       assert(MI->findRegisterDefOperandIdx(ARM::VPR) != -1 &&
894              "VPT does not implicitly define VPR?!");
895       CurrentPredicate.clear();
896       CurrentPredicate.insert(MI);
897     }
898 
899     VPTBlocks.emplace_back(MI, CurrentPredicate);
900     CurrentBlock = &VPTBlocks.back();
901     return true;
902   }
903 
904   bool IsUse = false;
905   bool IsDef = false;
906   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
907     const MachineOperand &MO = MI->getOperand(i);
908     if (!MO.isReg() || MO.getReg() != ARM::VPR)
909       continue;
910 
911     if (MO.isDef()) {
912       CurrentPredicate.insert(MI);
913       IsDef = true;
914     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
915       CurrentBlock->addInst(MI, CurrentPredicate);
916       IsUse = true;
917     } else {
918       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
919       return false;
920     }
921   }
922 
923   // If this instruction defines the VPR, update the predicate for the
924   // following instructions.
925   if (IsDef) {
926     // Clear the existing predicate when we're not in VPT Active state.
927     if (!isVectorPredicated(MI))
928       CurrentPredicate.clear();
929     CurrentPredicate.insert(MI);
930     LLVM_DEBUG(dbgs() << "ARM Loops: Adding Predicate: " << *MI);
931   }
932 
933   // If we find a vpr def that is not already predicated on the vctp, we've
934   // got disjoint predicates that may not be equivalent when we do the
935   // conversion.
936   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
937     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
938     return false;
939   }
940 
941   // If we find an instruction that has been marked as not valid for tail
942   // predication, only allow the instruction if it's contained within a valid
943   // VPT block.
944   if ((Flags & ARMII::ValidForTailPredication) == 0) {
945     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
946     return IsUse;
947   }
948 
949   // If the instruction is already explicitly predicated, then the conversion
950   // will be fine, but ensure that all store operations are predicated.
951   return IsUse || !MI->mayStore();
952 }
953 
954 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
955   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
956   if (!ST.hasLOB())
957     return false;
958 
959   MF = &mf;
960   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
961 
962   MLI = &getAnalysis<MachineLoopInfo>();
963   RDA = &getAnalysis<ReachingDefAnalysis>();
964   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
965   MRI = &MF->getRegInfo();
966   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
967   TRI = ST.getRegisterInfo();
968   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
969   BBUtils->computeAllBlockSizes();
970   BBUtils->adjustBBOffsetsAfter(&MF->front());
971 
972   bool Changed = false;
973   for (auto ML : *MLI) {
974     if (!ML->getParentLoop())
975       Changed |= ProcessLoop(ML);
976   }
977   Changed |= RevertNonLoops();
978   return Changed;
979 }
980 
981 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
982 
983   bool Changed = false;
984 
985   // Process inner loops first.
986   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
987     Changed |= ProcessLoop(*I);
988 
989   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
990              if (auto *Preheader = ML->getLoopPreheader())
991                dbgs() << " - " << Preheader->getName() << "\n";
992              else if (auto *Preheader = MLI->findLoopPreheader(ML))
993                dbgs() << " - " << Preheader->getName() << "\n";
994              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
995                dbgs() << " - " << Preheader->getName() << "\n";
996              for (auto *MBB : ML->getBlocks())
997                dbgs() << " - " << MBB->getName() << "\n";
998             );
999 
1000   // Search the given block for a loop start instruction. If one isn't found,
1001   // and there's only one predecessor block, search that one too.
1002   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
1003     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
1004     for (auto &MI : *MBB) {
1005       if (isLoopStart(MI))
1006         return &MI;
1007     }
1008     if (MBB->pred_size() == 1)
1009       return SearchForStart(*MBB->pred_begin());
1010     return nullptr;
1011   };
1012 
1013   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII);
1014   // Search the preheader for the start intrinsic.
1015   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
1016   // with potentially multiple set.loop.iterations, so we need to enable this.
1017   if (LoLoop.Preheader)
1018     LoLoop.Start = SearchForStart(LoLoop.Preheader);
1019   else
1020     return false;
1021 
1022   // Find the low-overhead loop components and decide whether or not to fall
1023   // back to a normal loop. Also look for a vctp instruction and decide
1024   // whether we can convert that predicate using tail predication.
1025   for (auto *MBB : reverse(ML->getBlocks())) {
1026     for (auto &MI : *MBB) {
1027       if (MI.isDebugValue())
1028         continue;
1029       else if (MI.getOpcode() == ARM::t2LoopDec)
1030         LoLoop.Dec = &MI;
1031       else if (MI.getOpcode() == ARM::t2LoopEnd)
1032         LoLoop.End = &MI;
1033       else if (isLoopStart(MI))
1034         LoLoop.Start = &MI;
1035       else if (MI.getDesc().isCall()) {
1036         // TODO: Though the call will require LE to execute again, does this
1037         // mean we should revert? Always executing LE hopefully should be
1038         // faster than performing a sub,cmp,br or even subs,br.
1039         LoLoop.Revert = true;
1040         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
1041       } else {
1042         // Record VPR defs and build up their corresponding vpt blocks.
1043         // Check we know how to tail predicate any mve instructions.
1044         LoLoop.AnalyseMVEInst(&MI);
1045       }
1046     }
1047   }
1048 
1049   LLVM_DEBUG(LoLoop.dump());
1050   if (!LoLoop.FoundAllComponents()) {
1051     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
1052     return false;
1053   }
1054 
1055   // Check that the only instruction using LoopDec is LoopEnd.
1056   // TODO: Check for copy chains that really have no effect.
1057   SmallPtrSet<MachineInstr*, 2> Uses;
1058   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
1059   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
1060     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
1061     LoLoop.Revert = true;
1062   }
1063   LoLoop.CheckLegality(BBUtils.get());
1064   Expand(LoLoop);
1065   return true;
1066 }
1067 
1068 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
1069 // beq that branches to the exit block.
1070 // TODO: We could also try to generate a cbz if the value in LR is also in
1071 // another low register.
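//
// Roughly (a sketch; $rN stands for the loop count operand):
//   t2WhileLoopStart $rN, %exit
// becomes:
//   t2CMPri $rN, 0
//   t2Bcc   %exit, eq       ; or tBcc when the branch target is in range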
1072 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
1073   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
1074   MachineBasicBlock *MBB = MI->getParent();
1075   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1076                                     TII->get(ARM::t2CMPri));
1077   MIB.add(MI->getOperand(0));
1078   MIB.addImm(0);
1079   MIB.addImm(ARMCC::AL);
1080   MIB.addReg(ARM::NoRegister);
1081 
1082   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1083   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1084     ARM::tBcc : ARM::t2Bcc;
1085 
1086   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1087   MIB.add(MI->getOperand(1));   // branch target
1088   MIB.addImm(ARMCC::EQ);        // condition code
1089   MIB.addReg(ARM::CPSR);
1090   MI->eraseFromParent();
1091 }
1092 
1093 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
1094   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
1095   MachineBasicBlock *MBB = MI->getParent();
1096   SmallPtrSet<MachineInstr*, 1> Ignore;
1097   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
1098     if (I->getOpcode() == ARM::t2LoopEnd) {
1099       Ignore.insert(&*I);
1100       break;
1101     }
1102   }
1103 
1104   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
1105   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
1106 
1107   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1108                                     TII->get(ARM::t2SUBri));
1109   MIB.addDef(ARM::LR);
1110   MIB.add(MI->getOperand(1));
1111   MIB.add(MI->getOperand(2));
1112   MIB.addImm(ARMCC::AL);
1113   MIB.addReg(0);
1114 
1115   if (SetFlags) {
1116     MIB.addReg(ARM::CPSR);
1117     MIB->getOperand(5).setIsDef(true);
1118   } else
1119     MIB.addReg(0);
1120 
1121   MI->eraseFromParent();
1122   return SetFlags;
1123 }
1124 
1125 // Generate a subs, or sub and cmp, and a branch instead of an LE.
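//
// Roughly (a sketch):
//   t2LoopEnd $lr, %header
// becomes:
//   t2CMPri $lr, 0          ; skipped when RevertLoopDec already set the flags
//   t2Bcc   %header, ne     ; or tBcc when the branch target is in range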
1126 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1127   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1128 
1129   MachineBasicBlock *MBB = MI->getParent();
1130   // Create cmp
1131   if (!SkipCmp) {
1132     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1133                                       TII->get(ARM::t2CMPri));
1134     MIB.addReg(ARM::LR);
1135     MIB.addImm(0);
1136     MIB.addImm(ARMCC::AL);
1137     MIB.addReg(ARM::NoRegister);
1138   }
1139 
1140   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1141   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1142     ARM::tBcc : ARM::t2Bcc;
1143 
1144   // Create bne
1145   MachineInstrBuilder MIB =
1146     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1147   MIB.add(MI->getOperand(1));   // branch target
1148   MIB.addImm(ARMCC::NE);        // condition code
1149   MIB.addReg(ARM::CPSR);
1150   MI->eraseFromParent();
1151 }
1152 
1153 // Perform dead code elimination on the loop iteration count setup expression.
1154 // If we are tail-predicating, the number of elements to be processed is the
1155 // operand of the VCTP instruction in the vector body, see getCount(), which is
1156 // register $r3 in this example:
1157 //
1158 //   $lr = big-itercount-expression
1159 //   ..
1160 //   t2DoLoopStart renamable $lr
1161 //   vector.body:
1162 //     ..
1163 //     $vpr = MVE_VCTP32 renamable $r3
1164 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1165 //     t2LoopEnd renamable $lr, %vector.body
1166 //     tB %end
1167 //
1168 // What we would like to achieve here is to replace the do-loop start pseudo
1169 // instruction t2DoLoopStart with:
1170 //
1171 //    $lr = MVE_DLSTP_32 killed renamable $r3
1172 //
1173 // Thus, $r3, which defines the number of elements, is written to $lr,
1174 // and then we want to delete the whole chain that used to define $lr;
1175 // see the comment below for how this chain could look.
1176 //
1177 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1178   if (!LoLoop.IsTailPredicationLegal())
1179     return;
1180 
1181   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1182 
1183   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1184   if (!Def) {
1185     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1186     return;
1187   }
1188 
1189   // Collect and remove the users of iteration count.
1190   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1191                                             LoLoop.End, LoLoop.InsertPt };
1192   SmallPtrSet<MachineInstr*, 2> Remove;
1193   if (RDA->isSafeToRemove(Def, Remove, Killed))
1194     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1195   else {
1196     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1197     return;
1198   }
1199 
1200   // Collect the dead code and the MBBs in which they reside.
1201   RDA->collectKilledOperands(Def, Killed);
1202   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1203   for (auto *MI : Killed)
1204     BasicBlocks.insert(MI->getParent());
1205 
1206   // Collect IT blocks in all affected basic blocks.
1207   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1208   for (auto *MBB : BasicBlocks) {
1209     for (auto &MI : *MBB) {
1210       if (MI.getOpcode() != ARM::t2IT)
1211         continue;
1212       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1213     }
1214   }
1215 
1216   // If we're removing all of the instructions within an IT block, then
1217   // also remove the IT instruction.
1218   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1219   for (auto *MI : Killed) {
1220     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1221       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1222       auto &CurrentBlock = ITBlocks[IT];
1223       CurrentBlock.erase(MI);
1224       if (CurrentBlock.empty())
1225         ModifiedITs.erase(IT);
1226       else
1227         ModifiedITs.insert(IT);
1228     }
1229   }
1230 
1231   // Delete the killed instructions only if we don't have any IT blocks that
1232   // need to be modified because we need to fixup the mask.
1233   // TODO: Handle cases where IT blocks are modified.
1234   if (ModifiedITs.empty()) {
1235     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1236                for (auto *MI : Killed)
1237                  dbgs() << " - " << *MI);
1238     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1239   } else
1240     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1241 }
1242 
1243 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1244   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1245   // When using tail-predication, try to delete the dead code that was used to
1246   // calculate the number of loop iterations.
1247   IterationCountDCE(LoLoop);
1248 
1249   MachineInstr *InsertPt = LoLoop.InsertPt;
1250   MachineInstr *Start = LoLoop.Start;
1251   MachineBasicBlock *MBB = InsertPt->getParent();
1252   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1253   unsigned Opc = LoLoop.getStartOpcode();
1254   MachineOperand &Count = LoLoop.getLoopStartOperand();
1255 
1256   MachineInstrBuilder MIB =
1257     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1258 
1259   MIB.addDef(ARM::LR);
1260   MIB.add(Count);
1261   if (!IsDo)
1262     MIB.add(Start->getOperand(1));
1263 
1264   // If we're inserting at a mov lr, then remove it as it's redundant.
1265   if (InsertPt != Start)
1266     LoLoop.ToRemove.insert(InsertPt);
1267   LoLoop.ToRemove.insert(Start);
1268   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1269   return &*MIB;
1270 }
1271 
1272 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1273   auto RemovePredicate = [](MachineInstr *MI) {
1274     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1275     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1276       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1277              "Expected Then predicate!");
1278       MI->getOperand(PIdx).setImm(ARMVCC::None);
1279       MI->getOperand(PIdx+1).setReg(0);
1280     } else
1281       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1282   };
1283 
1284   // There are a few scenarios which we have to fix up:
1285   // 1. VPT Blocks with non-uniform predicates:
1286   //    - a. When the divergent instruction is a vctp
1287   //    - b. When the block uses a vpst, and is only predicated on the vctp
1288   //    - c. When the block uses a vpt and (optionally) contains one or more
1289   //         vctp.
1290   //         vctps.
1291   //    - a. The block uses a vpst, and is only predicated on the vctp
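  //
  // A sketch of case 1.b (operands elided, only indicative):
  //
  //   before:  VPST ; I1 (pred) ; VCMP (divergent vpr def) ; I2 (pred)
  //   after:          I1 (unpred) ; VPT built from the VCMP ; I2 (pred)
  //
  // (or a fresh VPST if the divergent def isn't a VCMP), with the new block
  // mask left to be recomputed later.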
1292   for (auto &Block : LoLoop.getVPTBlocks()) {
1293     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1294     if (Block.HasNonUniformPredicate()) {
1295       PredicatedMI *Divergent = Block.getDivergent();
1296       if (isVCTP(Divergent->MI)) {
1297         // The vctp will be removed, so the block mask of the vp(s)t will need
1298         // to be recomputed.
1299         LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1300       } else if (Block.isVPST() && Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1301         // The VPT block has a non-uniform predicate but it uses a vpst and its
1302         // entry is guarded only by a vctp, which means we:
1303         // - Need to remove the original vpst.
1304         // - Then need to unpredicate any following instructions, until
1305         //   we come across the divergent vpr def.
1306         // - Insert a new vpst to predicate the instruction(s) that follow
1307         //   the divergent vpr def.
1308         // TODO: We could be producing more VPT blocks than necessary and could
1309         // fold the newly created one into a preceding one.
1310         for (auto I = ++MachineBasicBlock::iterator(Block.getPredicateThen()),
1311              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1312           RemovePredicate(&*I);
1313 
1314         // Check if the instruction defining vpr is a vcmp so it can be combined
1315         // with the VPST. This should be the divergent instruction.
1316         MachineInstr *VCMP = VCMPOpcodeToVPT(Divergent->MI->getOpcode()) != 0
1317                                  ? Divergent->MI
1318                                  : nullptr;
1319 
1320         unsigned Size = 0;
1321         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1322         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1323         MachineInstr *InsertAt = nullptr;
1324         while (I != E) {
1325           InsertAt = &*I;
1326           ++Size;
1327           ++I;
1328         }
1329         MachineInstrBuilder MIB;
1330         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: "
1331                           << *Block.getPredicateThen());
1332         if (VCMP) {
1333           // Combine the VPST and VCMP into a VPT
1334           MIB =
1335               BuildMI(*InsertAt->getParent(), InsertAt, InsertAt->getDebugLoc(),
1336                       TII->get(VCMPOpcodeToVPT(VCMP->getOpcode())));
1337           MIB.addImm(ARMVCC::Then);
1338           // The first source operand of the compare
1339           MIB.add(VCMP->getOperand(1));
1340           // The second source operand of the compare
1341           MIB.add(VCMP->getOperand(2));
1342           // The comparison code, e.g. ge, eq, lt
1343           MIB.add(VCMP->getOperand(3));
1344           LLVM_DEBUG(dbgs()
1345                      << "ARM Loops: Combining with VCMP to VPT: " << *MIB);
1346           LoLoop.ToRemove.insert(VCMP);
1347         } else {
1348           // No VCMP to fold, so just create a VPST with a null mask for
1349           // now; the real block mask will be recomputed later.
1350           MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1351                         InsertAt->getDebugLoc(), TII->get(ARM::MVE_VPST));
1352           MIB.addImm(0);
1353           LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1354         }
1355         LoLoop.ToRemove.insert(Block.getPredicateThen());
1356         LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
1357       }
1358       // Else, if the block uses a vpt, iterate over the block, removing the
1359       // extra VCTPs it may contain.
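      // As an illustrative sketch of scenario 1.c: a block headed by a VPT
      // (rather than a VPST) that contains one or more vctps has those vctps
      // erased below, and the VPT's block mask is then recomputed to match
      // the shorter body.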
1360       else if (Block.isVPT()) {
1361         bool RemovedVCTP = false;
1362         for (PredicatedMI &Elt : Block.getInsts()) {
1363           MachineInstr *MI = Elt.MI;
1364           if (isVCTP(MI)) {
1365             LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *MI);
1366             LoLoop.ToRemove.insert(MI);
1367             RemovedVCTP = true;
1368             continue;
1369           }
1370         }
1371         if (RemovedVCTP)
1372           LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen());
1373       }
1374     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP) && Block.isVPST()) {
1375       // A vpt block that starts with a VPST, is predicated only on the vctp,
1376       // and has no internal vpr defs:
1377       // - Remove vpst.
1378       // - Unpredicate the remaining instructions.
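      // As an illustrative sketch (opcodes and block mask simplified):
      //   VPST
      //   VLDR... (predicated only on the vctp)
      //   VSTR... (predicated only on the vctp)
      // becomes the same two memory operations, unpredicated, since the
      // tail-predicated loop now performs the equivalent lane masking itself.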
1379       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen());
1380       LoLoop.ToRemove.insert(Block.getPredicateThen());
1381       for (auto &PredMI : Insts)
1382         RemovePredicate(PredMI.MI);
1383     }
1384   }
1385   LLVM_DEBUG(dbgs() << "ARM Loops: Removing remaining VCTPs...\n");
1386   // Remove the "main" VCTP
1387   LoLoop.ToRemove.insert(LoLoop.VCTP);
1388   LLVM_DEBUG(dbgs() << "    " << *LoLoop.VCTP);
1389   // Remove remaining secondary VCTPs
1390   for (MachineInstr *VCTP : LoLoop.SecondaryVCTPs) {
1391     // All VCTPs that aren't marked for removal yet should be unpredicated ones.
1392     // The predicated ones should have already been marked for removal when
1393     // visiting the VPT blocks.
1394     if (LoLoop.ToRemove.insert(VCTP).second) {
1395       assert(getVPTInstrPredicate(*VCTP) == ARMVCC::None &&
1396              "Removing Predicated VCTP without updating the block mask!");
1397       LLVM_DEBUG(dbgs() << "    " << *VCTP);
1398     }
1399   }
1400 }
1401 
1402 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1403 
1404   // Combine the LoopDec and LoopEnd instructions into LE(TP).
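  // As an illustrative sketch (MIR-like; register, operand and block names
  // are simplified), the pseudos:
  //   $lr = t2LoopDec $lr, 1
  //   t2LoopEnd $lr, %bb.body
  // are replaced by a single:
  //   $lr = MVE_LETP $lr, %bb.body
  // when tail-predication is legal, or by t2LEUpdate otherwise.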
1405   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1406     MachineInstr *End = LoLoop.End;
1407     MachineBasicBlock *MBB = End->getParent();
1408     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1409       ARM::MVE_LETP : ARM::t2LEUpdate;
1410     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1411                                       TII->get(Opc));
1412     MIB.addDef(ARM::LR);
1413     MIB.add(End->getOperand(0));
1414     MIB.add(End->getOperand(1));
1415     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1416     LoLoop.ToRemove.insert(LoLoop.Dec);
1417     LoLoop.ToRemove.insert(End);
1418     return &*MIB;
1419   };
1420 
1421   // TODO: We should be able to automatically remove these branches before we
1422   // get here - probably by teaching analyzeBranch about the pseudo
1423   // instructions.
1424   // If there is an unconditional branch, after I, that just branches to the
1425   // next block, remove it.
1426   auto RemoveDeadBranch = [](MachineInstr *I) {
1427     MachineBasicBlock *BB = I->getParent();
1428     MachineInstr *Terminator = &BB->instr_back();
1429     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1430       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1431       if (BB->isLayoutSuccessor(Succ)) {
1432         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1433         Terminator->eraseFromParent();
1434       }
1435     }
1436   };
1437 
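  // Either revert the whole loop back to a conventional compare-and-branch
  // form, or expand the pseudos into DLS(TP)/WLS(TP) and LE(TP), converting
  // the VPT blocks when tail-predicating and erasing anything that is no
  // longer needed.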
1438   if (LoLoop.Revert) {
1439     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1440       RevertWhile(LoLoop.Start);
1441     else
1442       LoLoop.Start->eraseFromParent();
1443     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
1444     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1445   } else {
1446     LoLoop.Start = ExpandLoopStart(LoLoop);
1447     RemoveDeadBranch(LoLoop.Start);
1448     LoLoop.End = ExpandLoopEnd(LoLoop);
1449     RemoveDeadBranch(LoLoop.End);
1450     if (LoLoop.IsTailPredicationLegal())
1451       ConvertVPTBlocks(LoLoop);
1452     for (auto *I : LoLoop.ToRemove) {
1453       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1454       I->eraseFromParent();
1455     }
1456     for (auto *I : LoLoop.BlockMasksToRecompute) {
1457       LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
1458       recomputeVPTBlockMask(*I);
1459       LLVM_DEBUG(dbgs() << "           ... done: " << *I);
1460     }
1461   }
1462 
1463   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1464   DFS.ProcessLoop();
1465   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1466   for (auto *MBB : PostOrder) {
1467     recomputeLiveIns(*MBB);
1468     // FIXME: For some reason, the live-in print order is non-deterministic for
1469     // our tests and I can't work out why... So just sort them.
1470     MBB->sortUniqueLiveIns();
1471   }
1472 
1473   for (auto *MBB : reverse(PostOrder))
1474     recomputeLivenessFlags(*MBB);
1475 
1476   // We've moved, removed and inserted new instructions, so update RDA.
1477   RDA->reset();
1478 }
1479 
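// Lower any t2WhileLoopStart, t2LoopDec and t2LoopEnd pseudos that were never
// associated with a low-overhead loop, or whose loop was abandoned. As a
// rough sketch of the effect: a while start reverts to a compare and branch,
// a dec to a subtract, and an end to a compare and branch back to the loop
// header, while the do-loop form of the start is simply erased.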
1480 bool ARMLowOverheadLoops::RevertNonLoops() {
1481   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1482   bool Changed = false;
1483 
1484   for (auto &MBB : *MF) {
1485     SmallVector<MachineInstr*, 4> Starts;
1486     SmallVector<MachineInstr*, 4> Decs;
1487     SmallVector<MachineInstr*, 4> Ends;
1488 
1489     for (auto &I : MBB) {
1490       if (isLoopStart(I))
1491         Starts.push_back(&I);
1492       else if (I.getOpcode() == ARM::t2LoopDec)
1493         Decs.push_back(&I);
1494       else if (I.getOpcode() == ARM::t2LoopEnd)
1495         Ends.push_back(&I);
1496     }
1497 
1498     if (Starts.empty() && Decs.empty() && Ends.empty())
1499       continue;
1500 
1501     Changed = true;
1502 
1503     for (auto *Start : Starts) {
1504       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1505         RevertWhile(Start);
1506       else
1507         Start->eraseFromParent();
1508     }
1509     for (auto *Dec : Decs)
1510       RevertLoopDec(Dec);
1511 
1512     for (auto *End : Ends)
1513       RevertLoopEnd(End);
1514   }
1515   return Changed;
1516 }
1517 
1518 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1519   return new ARMLowOverheadLoops();
1520 }
1521