1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
14 ///   preheader's only predecessor.
15 /// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 /// In addition to this, we also look for the presence of the VCTP instruction,
19 /// which determines whether we can generate the tail-predicated low-overhead
20 /// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
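/// As a rough sketch (not an exact MIR dump), a do-loop of the form:
///
///   preheader:
///     t2DoLoopStart $lr
///   loop:
///     ...
///     $lr = t2LoopDec $lr, 1
///     t2LoopEnd $lr, %loop
///
/// is finalized into a DLS/LE pair:
///
///   preheader:
///     $lr = t2DLS $lr
///   loop:
///     ...
///     $lr = t2LEUpdate $lr, %loop
///
/// or into the DLSTP/LETP tail-predicated equivalents when that is valid.
///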
38 //===----------------------------------------------------------------------===//
39 
40 #include "ARM.h"
41 #include "ARMBaseInstrInfo.h"
42 #include "ARMBaseRegisterInfo.h"
43 #include "ARMBasicBlockInfo.h"
44 #include "ARMSubtarget.h"
45 #include "Thumb2InstrInfo.h"
46 #include "llvm/ADT/SetOperations.h"
47 #include "llvm/ADT/SmallSet.h"
48 #include "llvm/CodeGen/LivePhysRegs.h"
49 #include "llvm/CodeGen/MachineFunctionPass.h"
50 #include "llvm/CodeGen/MachineLoopInfo.h"
51 #include "llvm/CodeGen/MachineLoopUtils.h"
52 #include "llvm/CodeGen/MachineRegisterInfo.h"
53 #include "llvm/CodeGen/Passes.h"
54 #include "llvm/CodeGen/ReachingDefAnalysis.h"
55 #include "llvm/MC/MCInstrDesc.h"
56 
57 using namespace llvm;
58 
59 #define DEBUG_TYPE "arm-low-overhead-loops"
60 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
61 
62 namespace {
63 
64   using InstSet = SmallPtrSetImpl<MachineInstr *>;
65 
66   class PostOrderLoopTraversal {
67     MachineLoop &ML;
68     MachineLoopInfo &MLI;
69     SmallPtrSet<MachineBasicBlock*, 4> Visited;
70     SmallVector<MachineBasicBlock*, 4> Order;
71 
72   public:
73     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
74       : ML(ML), MLI(MLI) { }
75 
76     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
77       return Order;
78     }
79 
80     // Visit all the blocks within the loop, as well as exit blocks and any
81     // blocks properly dominating the header.
82     void ProcessLoop() {
83       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
84         (MachineBasicBlock *MBB) -> void {
85         if (Visited.count(MBB))
86           return;
87 
88         Visited.insert(MBB);
89         for (auto *Succ : MBB->successors()) {
90           if (!ML.contains(Succ))
91             continue;
92           Search(Succ);
93         }
94         Order.push_back(MBB);
95       };
96 
97       // Insert exit blocks.
98       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
99       ML.getExitBlocks(ExitBlocks);
100       for (auto *MBB : ExitBlocks)
101         Order.push_back(MBB);
102 
103       // Then add the loop body.
104       Search(ML.getHeader());
105 
106       // Then try the preheader and its predecessors.
107       std::function<void(MachineBasicBlock*)> GetPredecessor =
108         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
109         Order.push_back(MBB);
110         if (MBB->pred_size() == 1)
111           GetPredecessor(*MBB->pred_begin());
112       };
113 
114       if (auto *Preheader = ML.getLoopPreheader())
115         GetPredecessor(Preheader);
116       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
117         GetPredecessor(Preheader);
118     }
119   };
120 
121   struct PredicatedMI {
122     MachineInstr *MI = nullptr;
123     SetVector<MachineInstr*> Predicates;
124 
125   public:
126     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
127       MI(I) { Predicates.insert(Preds.begin(), Preds.end()); }
128   };
129 
130   // Represents a VPT block: a list of instructions that begins with a VPST and
131   // is followed by a maximum of four instructions. All instructions within the
132   // block are predicated upon the vpr and we allow instructions to define the
133   // vpr within the block too.
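  // For example, in assembly (an illustrative sketch, not from a real test):
  //
  //   vpstt
  //   vldrwt.u32 q0, [r0]
  //   vaddt.i32  q1, q1, q0
  //
  // forms a block of two instructions predicated by the preceding VPST.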
134   class VPTBlock {
135     std::unique_ptr<PredicatedMI> VPST;
136     PredicatedMI *Divergent = nullptr;
137     SmallVector<PredicatedMI, 4> Insts;
138 
139   public:
140     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
141       VPST = std::make_unique<PredicatedMI>(MI, Preds);
142     }
143 
144     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
145       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
146       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
147         Divergent = &Insts.back();
148         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
149       }
150       Insts.emplace_back(MI, Preds);
151       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
152     }
153 
154     // Have we found an instruction within the block which defines the vpr? If
155     // so, not all the instructions in the block will have the same predicate.
156     bool HasNonUniformPredicate() const {
157       return Divergent != nullptr;
158     }
159 
160     // Is the given instruction part of the predicate set controlling the entry
161     // to the block.
162     bool IsPredicatedOn(MachineInstr *MI) const {
163       return VPST->Predicates.count(MI);
164     }
165 
166     // Is the given instruction the only predicate which controls the entry to
167     // the block.
168     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
169       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
170     }
171 
172     unsigned size() const { return Insts.size(); }
173     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
174     MachineInstr *getVPST() const { return VPST->MI; }
175     PredicatedMI *getDivergent() const { return Divergent; }
176   };
177 
178   struct LowOverheadLoop {
179 
180     MachineLoop &ML;
181     MachineLoopInfo &MLI;
182     ReachingDefAnalysis &RDA;
183     const TargetRegisterInfo &TRI;
184     MachineFunction *MF = nullptr;
185     MachineInstr *InsertPt = nullptr;
186     MachineInstr *Start = nullptr;
187     MachineInstr *Dec = nullptr;
188     MachineInstr *End = nullptr;
189     MachineInstr *VCTP = nullptr;
190     VPTBlock *CurrentBlock = nullptr;
191     SetVector<MachineInstr*> CurrentPredicate;
192     SmallVector<VPTBlock, 4> VPTBlocks;
193     SmallPtrSet<MachineInstr*, 4> ToRemove;
194     SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
195     bool Revert = false;
196     bool CannotTailPredicate = false;
197 
198     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
199                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI)
200       : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI) {
201       MF = ML.getHeader()->getParent();
202     }
203 
204     // If this is an MVE instruction, check that we know how to use tail
205     // predication with it. Record VPT blocks and return whether the
206     // instruction is valid for tail predication.
207     bool ValidateMVEInst(MachineInstr *MI);
208 
209     void AnalyseMVEInst(MachineInstr *MI) {
210       CannotTailPredicate = !ValidateMVEInst(MI);
211     }
212 
213     bool IsTailPredicationLegal() const {
214       // For now, let's keep things really simple and only support a single
215       // block for tail predication.
216       return !Revert && FoundAllComponents() && VCTP &&
217              !CannotTailPredicate && ML.getNumBlocks() == 1;
218     }
219 
220     // Check that the predication in the loop will be equivalent once we
221     // perform the conversion. Also ensure that we can provide the number
222     // of elements to the loop start instruction.
223     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
224 
225     // Check that any values available outside of the loop will be the same
226     // after tail predication conversion.
227     bool ValidateLiveOuts() const;
228 
229     // Is it safe to define LR with DLS/WLS?
230     // LR can be defined if it is the operand to start, because it's the same
231     // value, or if it's going to be equivalent to the operand to Start.
232     MachineInstr *isSafeToDefineLR();
233 
234     // Check that the branch targets are within range and that we satisfy our
235     // restrictions.
236     void CheckLegality(ARMBasicBlockUtils *BBUtils);
237 
238     bool FoundAllComponents() const {
239       return Start && Dec && End;
240     }
241 
242     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
243 
244     // Return the loop iteration count, or the number of elements if we're tail
245     // predicating.
246     MachineOperand &getCount() {
247       return IsTailPredicationLegal() ?
248         VCTP->getOperand(1) : Start->getOperand(0);
249     }
250 
251     unsigned getStartOpcode() const {
252       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
253       if (!IsTailPredicationLegal())
254         return IsDo ? ARM::t2DLS : ARM::t2WLS;
255 
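      // Otherwise pick the tail-predicated form. For instance (an assumption
      // based on the VCTPOpcodeToLSTP mapping, consistent with the
      // IterationCountDCE example below), MVE_VCTP32 with a do-loop start
      // yields MVE_DLSTP_32.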
256       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
257     }
258 
259     void dump() const {
260       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
261       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
262       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
263       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
264       if (!FoundAllComponents())
265         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
266       else if (!(Start && Dec && End))
267         dbgs() << "ARM Loops: Failed to find all loop components.\n";
268     }
269   };
270 
271   class ARMLowOverheadLoops : public MachineFunctionPass {
272     MachineFunction           *MF = nullptr;
273     MachineLoopInfo           *MLI = nullptr;
274     ReachingDefAnalysis       *RDA = nullptr;
275     const ARMBaseInstrInfo    *TII = nullptr;
276     MachineRegisterInfo       *MRI = nullptr;
277     const TargetRegisterInfo  *TRI = nullptr;
278     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
279 
280   public:
281     static char ID;
282 
283     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
284 
285     void getAnalysisUsage(AnalysisUsage &AU) const override {
286       AU.setPreservesCFG();
287       AU.addRequired<MachineLoopInfo>();
288       AU.addRequired<ReachingDefAnalysis>();
289       MachineFunctionPass::getAnalysisUsage(AU);
290     }
291 
292     bool runOnMachineFunction(MachineFunction &MF) override;
293 
294     MachineFunctionProperties getRequiredProperties() const override {
295       return MachineFunctionProperties().set(
296           MachineFunctionProperties::Property::NoVRegs).set(
297           MachineFunctionProperties::Property::TracksLiveness);
298     }
299 
300     StringRef getPassName() const override {
301       return ARM_LOW_OVERHEAD_LOOPS_NAME;
302     }
303 
304   private:
305     bool ProcessLoop(MachineLoop *ML);
306 
307     bool RevertNonLoops();
308 
309     void RevertWhile(MachineInstr *MI) const;
310 
311     bool RevertLoopDec(MachineInstr *MI) const;
312 
313     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
314 
315     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
316 
317     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
318 
319     void Expand(LowOverheadLoop &LoLoop);
320 
321     void IterationCountDCE(LowOverheadLoop &LoLoop);
322   };
323 }
324 
325 char ARMLowOverheadLoops::ID = 0;
326 
327 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
328                 false, false)
329 
330 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
331   // We can define LR because LR already contains the same value.
332   if (Start->getOperand(0).getReg() == ARM::LR)
333     return Start;
334 
335   unsigned CountReg = Start->getOperand(0).getReg();
336   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
337     return MI->getOpcode() == ARM::tMOVr &&
338            MI->getOperand(0).getReg() == ARM::LR &&
339            MI->getOperand(1).getReg() == CountReg &&
340            MI->getOperand(2).getImm() == ARMCC::AL;
341    };
342 
343   MachineBasicBlock *MBB = Start->getParent();
344 
345   // Find an insertion point:
346   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
347   //   to Count before Start, we can insert at that mov.
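  //   A hedged sketch of the pattern IsMoveLR accepts (using the operands it
  //   checks above, not taken from a real test):
  //     $lr = tMOVr $r2, 14 /* always */, $noreg
  //     ...
  //     t2DoLoopStart $r2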
348   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
349     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
350       return LRDef;
351 
352   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
353   //   to Count after Start, we can insert at that mov.
354   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
355     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
356       return LRDef;
357 
358   // We've found no suitable LR def and Start doesn't use LR directly. Can we
359   // just define LR anyway?
360   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
361 }
362 
363 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
364   assert(VCTP && "VCTP instruction expected but is not set");
365   // All predication within the loop should be based on vctp. If the block
366   // isn't predicated on entry, check whether the vctp is within the block
367   // and that all other instructions are then predicated on it.
368   for (auto &Block : VPTBlocks) {
369     if (Block.IsPredicatedOn(VCTP))
370       continue;
371     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) {
372       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
373                  << *Block.getDivergent()->MI);
374       return false;
375     }
376     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
377     for (auto &PredMI : Insts) {
378       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
379         continue;
380       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
381                  << " - which is predicated on:\n";
382                  for (auto *MI : PredMI.Predicates)
383                    dbgs() << "   - " << *MI);
384       return false;
385     }
386   }
387 
388   if (!ValidateLiveOuts())
389     return false;
390 
391   // For tail predication, we need to provide the number of elements, instead
392   // of the iteration count, to the loop start instruction. The number of
393   // elements is provided to the vctp instruction, so we need to check that
394   // we can use this register at InsertPt.
395   Register NumElements = VCTP->getOperand(1).getReg();
396 
397   // If the register is defined within loop, then we can't perform TP.
398   // If the register is defined within the loop, then we can't perform TP.
399   // available.
400   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
401     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
402     return false;
403   }
404 
405   // The element count register may be defined after InsertPt, in which case we
406   // need to try to move either InsertPt or the def so that the [w|d]lstp can
407   // use the value.
408   // TODO: On failing to move an instruction, check if the count is provided by
409   // a mov and whether we can use the mov operand directly.
410   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
411   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
412     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
413       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
414         ElemDef->removeFromParent();
415         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
416         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
417                    << *ElemDef);
418       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
419         StartInsertPt->removeFromParent();
420         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
421                               StartInsertPt);
422         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
423       } else {
424         LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
425                    << "start instruction.\n");
426         return false;
427       }
428     }
429   }
430 
431   // Especially in the case of while loops, InsertBB may not be the
432   // preheader, so we need to check that the register isn't redefined
433   // before entering the loop.
434   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
435                                       Register NumElements) {
436     // NumElements is redefined in this block.
437     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
438       return true;
439 
440     // Don't continue searching up through multiple predecessors.
441     if (MBB->pred_size() > 1)
442       return true;
443 
444     return false;
445   };
446 
447   // First, find the block that looks like the preheader.
448   MachineBasicBlock *MBB = MLI.findLoopPreheader(&ML, true);
449   if (!MBB) {
450     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
451     return false;
452   }
453 
454   // Then search backwards for a def, until we get to InsertBB.
455   while (MBB != InsertBB) {
456     if (CannotProvideElements(MBB, NumElements)) {
457       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
458       return false;
459     }
460     MBB = *MBB->pred_begin();
461   }
462 
463   // Check that the value change of the element count is what we expect and
464   // that the predication will be equivalent. For this we need:
465   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
466   // and we can also allow register copies within the chain too.
467   auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
468     return -getAddSubImmediate(*MI) == ExpectedVecWidth;
469   };
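  // For example, with a 32-bit vctp the expected width is four elements per
  // iteration, so the chain we accept is roughly '$rN = t2SUBri $rN, 4' plus
  // any register copies feeding the vctp operand (a sketch of the shape, not
  // an exact dump).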
470 
471   MBB = VCTP->getParent();
472   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
473     SmallPtrSet<MachineInstr*, 2> ElementChain;
474     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
475     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
476 
477     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
478       bool FoundSub = false;
479 
480       for (auto *MI : ElementChain) {
481         if (isMovRegOpcode(MI->getOpcode()))
482           continue;
483 
484         if (isSubImmOpcode(MI->getOpcode())) {
485           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
486             return false;
487           FoundSub = true;
488         } else
489           return false;
490       }
491 
492       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
493                  for (auto *MI : ElementChain)
494                    dbgs() << " - " << *MI);
495       ToRemove.insert(ElementChain.begin(), ElementChain.end());
496     }
497   }
498   return true;
499 }
500 
501 static bool isVectorPredicated(MachineInstr *MI) {
502   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
503   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
504 }
505 
506 static bool isRegInClass(const MachineOperand &MO,
507                          const TargetRegisterClass *Class) {
508   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
509 }
510 
511 // MVE 'narrowing' operations operate on half a lane, reading from half and
512 // writing to half, which are referred to as the top and bottom half. The
513 // other half retains its previous value.
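// An example would be the MVE vmovnb/vmovnt narrowing moves, which write only
// the bottom or top half of each destination lane and leave the other half
// unchanged.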
514 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
515   const MCInstrDesc &MCID = MI.getDesc();
516   uint64_t Flags = MCID.TSFlags;
517   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
518 }
519 
520 // Some MVE instructions read from the top/bottom halves of their operand(s)
521 // and generate a vector result with result elements that are double the
522 // width of the input.
523 static bool producesDoubleWidthResult(const MachineInstr &MI) {
524   const MCInstrDesc &MCID = MI.getDesc();
525   uint64_t Flags = MCID.TSFlags;
526   return (Flags & ARMII::DoubleWidthResult) != 0;
527 }
528 
529 static bool isHorizontalReduction(const MachineInstr &MI) {
530   const MCInstrDesc &MCID = MI.getDesc();
531   uint64_t Flags = MCID.TSFlags;
532   return (Flags & ARMII::HorizontalReduction) != 0;
533 }
534 
535 // Can this instruction generate a non-zero result when given only zeroed
536 // operands? This allows us to know that, given operands with false bytes
537 // zeroed by masked loads, the result will also contain zeros in those
538 // bytes.
539 static bool canGenerateNonZeros(const MachineInstr &MI) {
540 
541   // Check for instructions which can write into a larger element size,
542   // possibly writing into a previous zero'd lane.
543   if (producesDoubleWidthResult(MI))
544     return true;
545 
546   switch (MI.getOpcode()) {
547   default:
548     break;
549   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
550   // fp16 -> fp32 vector conversions.
551   // Instructions that perform a NOT will generate 1s from 0s.
552   case ARM::MVE_VMVN:
553   case ARM::MVE_VORN:
554   // Count leading zeros will do just that!
555   case ARM::MVE_VCLZs8:
556   case ARM::MVE_VCLZs16:
557   case ARM::MVE_VCLZs32:
558     return true;
559   }
560   return false;
561 }
562 
563 
564 // Look at the register uses of MI to see if it can only receive zeros
565 // into its false lanes, which would then produce zeros. Also check that
566 // the output register is defined by a FalseLanesZero instruction, so
567 // that if tail-predication happens, the lanes that aren't updated will
568 // still be zeros.
569 static bool producesFalseLanesZero(MachineInstr &MI,
570                                    const TargetRegisterClass *QPRs,
571                                    const ReachingDefAnalysis &RDA,
572                                    InstSet &FalseLanesZero) {
573   if (canGenerateNonZeros(MI))
574     return false;
575 
576   bool AllowScalars = isHorizontalReduction(MI);
577   for (auto &MO : MI.operands()) {
578     if (!MO.isReg() || !MO.getReg())
579       continue;
580     if (!isRegInClass(MO, QPRs) && AllowScalars)
581       continue;
582     if (auto *OpDef = RDA.getMIOperand(&MI, MO))
583       if (FalseLanesZero.count(OpDef))
584        continue;
585     return false;
586   }
587   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
588   return true;
589 }
590 
591 bool LowOverheadLoop::ValidateLiveOuts() const {
592   // We want to find out if the tail-predicated version of this loop will
593   // produce the same values as the loop in its original form. For this to
594   // be true, the newly inserted implicit predication must not change the
595   // (observable) results.
596   // We're doing this because many instructions in the loop will not be
597   // predicated and so the conversion from VPT predication to tail-predication
598   // can result in different values being produced, because tail-predication
599   // prevents many instructions from updating their falsely predicated
600   // lanes. This analysis assumes that all the instructions perform lane-wise
601   // operations and don't perform any exchanges.
602   // A masked load, whether through VPT or tail predication, will write zeros
603   // to any of the falsely predicated bytes. So, from the loads, we know that
604   // the false lanes are zeroed and here we're trying to track that those false
605   // lanes remain zero, or where they change, the differences are masked away
606   // by their user(s).
607   // All MVE loads and stores have to be predicated, so we know that any load
608   // operands, or stored results are equivalent already. Other explicitly
609   // predicated instructions will perform the same operation in the original
610   // loop and the tail-predicated form too. Because of this, we can insert
611   // loads, stores and other predicated instructions into our Predicated
612   // set and build from there.
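  // As a small sketch of the idea: a predicated vector load zeroes the false
  // lanes of its destination, a plain lane-wise vadd of two such values keeps
  // those lanes zero, and a predicated store never writes the false lanes, so
  // the observable result matches the original loop.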
613   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
614   SetVector<MachineInstr *> FalseLanesUnknown;
615   SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
616   SmallPtrSet<MachineInstr *, 4> Predicated;
617   MachineBasicBlock *MBB = ML.getHeader();
618 
619   for (auto &MI : *MBB) {
620     const MCInstrDesc &MCID = MI.getDesc();
621     uint64_t Flags = MCID.TSFlags;
622     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
623       continue;
624 
625     if (isVCTP(&MI) || MI.getOpcode() == ARM::MVE_VPST)
626       continue;
627 
628     // Predicated loads will write zeros to the falsely predicated bytes of the
629     // destination register.
630     if (isVectorPredicated(&MI)) {
631       if (MI.mayLoad())
632         FalseLanesZero.insert(&MI);
633       Predicated.insert(&MI);
634       continue;
635     }
636 
637     if (MI.getNumDefs() == 0)
638       continue;
639 
640     if (!producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero)) {
641       // We require retaining and horizontal operations to operate upon zero'd
642       // false lanes to ensure the conversion doesn't change the output.
643       if (retainsPreviousHalfElement(MI) || isHorizontalReduction(MI))
644         return false;
645       // Otherwise we need to evaluate this instruction later to see whether
646       // unknown false lanes will get masked away by their user(s).
647       FalseLanesUnknown.insert(&MI);
648     } else if (!isHorizontalReduction(MI))
649       FalseLanesZero.insert(&MI);
650   }
651 
652   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
653                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
654     SmallPtrSet<MachineInstr *, 2> Uses;
655     RDA.getGlobalUses(MI, MO.getReg(), Uses);
656     for (auto *Use : Uses) {
657       if (Use != MI && !Predicated.count(Use))
658         return false;
659     }
660     return true;
661   };
662 
663   // Visit the unknowns in reverse so that we can start at the values being
664   // stored and then we can work towards the leaves, hopefully adding more
665   // instructions to Predicated. Successfully terminating the loop means that
666   // all the unknown values have been found to be masked by predicated user(s).
667   for (auto *MI : reverse(FalseLanesUnknown)) {
668     for (auto &MO : MI->operands()) {
669       if (!isRegInClass(MO, QPRs) || !MO.isDef())
670         continue;
671       if (!HasPredicatedUsers(MI, MO, Predicated)) {
672         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
673                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
674         return false;
675       }
676     }
677     // Any unknown false lanes have been masked away by the user(s).
678     Predicated.insert(MI);
679   }
680 
681   // Collect Q-regs that are live in the exit blocks. We don't collect scalars
682   // because they won't be affected by lane predication.
683   SmallSet<Register, 2> LiveOuts;
684   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
685   ML.getExitBlocks(ExitBlocks);
686   for (auto *MBB : ExitBlocks)
687     for (const MachineBasicBlock::RegisterMaskPair &RegMask : MBB->liveins())
688       if (QPRs->contains(RegMask.PhysReg))
689         LiveOuts.insert(RegMask.PhysReg);
690 
691   // Collect the instructions in the loop body that define the live-out values.
692   SmallPtrSet<MachineInstr *, 2> LiveMIs;
693   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
694   for (auto Reg : LiveOuts)
695     if (auto *MI = RDA.getLocalLiveOutMIDef(MBB, Reg))
696       LiveMIs.insert(MI);
697 
698   LLVM_DEBUG(dbgs() << "ARM Loops: Found loop live-outs:\n";
699              for (auto *MI : LiveMIs)
700                dbgs() << " - " << *MI);
701   // We've already validated that any VPT predication within the loop will be
702   // equivalent when we perform the predication transformation; so we know that
703   // any VPT predicated instruction is predicated upon VCTP. Any live-out
704   // instruction needs to be predicated, so check this here.
705   for (auto *MI : LiveMIs)
706     if (!isVectorPredicated(MI))
707       return false;
708 
709   return true;
710 }
711 
712 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
713   if (Revert)
714     return;
715 
716   if (!End->getOperand(1).isMBB())
717     report_fatal_error("Expected LoopEnd to target basic block");
718 
719   // TODO: Maybe there are cases where the target doesn't have to be the header,
720   // but for now be safe and revert.
721   if (End->getOperand(1).getMBB() != ML.getHeader()) {
722     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targeting header.\n");
723     Revert = true;
724     return;
725   }
726 
727   // The WLS and LE instructions have 12-bits for the label offset. WLS
728   // requires a positive offset, while LE uses negative.
729   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
730       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
731     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
732     Revert = true;
733     return;
734   }
735 
736   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
737       (BBUtils->getOffsetOf(Start) >
738        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
739        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
740     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
741     Revert = true;
742     return;
743   }
744 
745   InsertPt = Revert ? nullptr : isSafeToDefineLR();
746   if (!InsertPt) {
747     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
748     Revert = true;
749     return;
750   } else
751     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
752 
753   if (!IsTailPredicationLegal()) {
754     LLVM_DEBUG(if (!VCTP)
755                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
756                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
757     return;
758   }
759 
760   assert(ML.getBlocks().size() == 1 &&
761          "Shouldn't be processing a loop with more than one block");
762   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
763   LLVM_DEBUG(if (CannotTailPredicate)
764              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
765 }
766 
767 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
768   if (CannotTailPredicate)
769     return false;
770 
771   // Only support a single vctp.
772   if (isVCTP(MI) && VCTP)
773     return false;
774 
775   // Start a new vpt block when we discover a vpt.
776   if (MI->getOpcode() == ARM::MVE_VPST) {
777     VPTBlocks.emplace_back(MI, CurrentPredicate);
778     CurrentBlock = &VPTBlocks.back();
779     return true;
780   } else if (isVCTP(MI))
781     VCTP = MI;
782   else if (MI->getOpcode() == ARM::MVE_VPSEL ||
783            MI->getOpcode() == ARM::MVE_VPNOT)
784     return false;
785 
786   // TODO: Allow VPSEL and VPNOT, we currently cannot because:
787   // 1) It will use the VPR as a predicate operand, but doesn't have to be
788   //    inside a VPT block, which means we can assert while building up
789   //    the VPT block because we don't find another VPST to begin a new
790   //    one.
791   // 2) VPSEL still requires a VPR operand even after tail predicating,
792   //    which means we can't remove it unless there is another
793   //    instruction, such as vcmp, that can provide the VPR def.
794 
795   bool IsUse = false;
796   bool IsDef = false;
797   const MCInstrDesc &MCID = MI->getDesc();
798   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
799     const MachineOperand &MO = MI->getOperand(i);
800     if (!MO.isReg() || MO.getReg() != ARM::VPR)
801       continue;
802 
803     if (MO.isDef()) {
804       CurrentPredicate.insert(MI);
805       IsDef = true;
806     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
807       CurrentBlock->addInst(MI, CurrentPredicate);
808       IsUse = true;
809     } else {
810       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
811       return false;
812     }
813   }
814 
815   // If we find a vpr def that is not already predicated on the vctp, we've
816   // got disjoint predicates that may not be equivalent when we do the
817   // conversion.
818   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
819     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
820     return false;
821   }
822 
823   uint64_t Flags = MCID.TSFlags;
824   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
825     return true;
826 
827   // If we find an instruction that has been marked as not valid for tail
828   // predication, only allow the instruction if it's contained within a valid
829   // VPT block.
830   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
831     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
832     return false;
833   }
834 
835   // If the instruction is already explicitly predicated, then the conversion
836   // will be fine, but ensure that all memory operations are predicated.
837   return IsUse || !MI->mayLoadOrStore();
838 }
839 
840 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
841   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
842   if (!ST.hasLOB())
843     return false;
844 
845   MF = &mf;
846   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
847 
848   MLI = &getAnalysis<MachineLoopInfo>();
849   RDA = &getAnalysis<ReachingDefAnalysis>();
850   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
851   MRI = &MF->getRegInfo();
852   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
853   TRI = ST.getRegisterInfo();
854   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
855   BBUtils->computeAllBlockSizes();
856   BBUtils->adjustBBOffsetsAfter(&MF->front());
857 
858   bool Changed = false;
859   for (auto ML : *MLI) {
860     if (!ML->getParentLoop())
861       Changed |= ProcessLoop(ML);
862   }
863   Changed |= RevertNonLoops();
864   return Changed;
865 }
866 
867 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
868 
869   bool Changed = false;
870 
871   // Process inner loops first.
872   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
873     Changed |= ProcessLoop(*I);
874 
875   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
876              if (auto *Preheader = ML->getLoopPreheader())
877                dbgs() << " - " << Preheader->getName() << "\n";
878              else if (auto *Preheader = MLI->findLoopPreheader(ML))
879                dbgs() << " - " << Preheader->getName() << "\n";
880              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
881                dbgs() << " - " << Preheader->getName() << "\n";
882              for (auto *MBB : ML->getBlocks())
883                dbgs() << " - " << MBB->getName() << "\n";
884             );
885 
886   // Search the given block for a loop start instruction. If one isn't found,
887   // and there's only one predecessor block, search that one too.
888   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
889     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
890     for (auto &MI : *MBB) {
891       if (isLoopStart(MI))
892         return &MI;
893     }
894     if (MBB->pred_size() == 1)
895       return SearchForStart(*MBB->pred_begin());
896     return nullptr;
897   };
898 
899   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI);
900   // Search the preheader for the start intrinsic.
901   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
902   // with potentially multiple set.loop.iterations, so we need to enable this.
903   if (auto *Preheader = ML->getLoopPreheader())
904     LoLoop.Start = SearchForStart(Preheader);
905   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
906     LoLoop.Start = SearchForStart(Preheader);
907   else
908     return false;
909 
910   // Find the low-overhead loop components and decide whether or not to fall
911   // back to a normal loop. Also look for vctp instructions and decide
912   // whether we can convert that predicate using tail predication.
913   for (auto *MBB : reverse(ML->getBlocks())) {
914     for (auto &MI : *MBB) {
915       if (MI.isDebugValue())
916         continue;
917       else if (MI.getOpcode() == ARM::t2LoopDec)
918         LoLoop.Dec = &MI;
919       else if (MI.getOpcode() == ARM::t2LoopEnd)
920         LoLoop.End = &MI;
921       else if (isLoopStart(MI))
922         LoLoop.Start = &MI;
923       else if (MI.getDesc().isCall()) {
924         // TODO: Though the call will require LE to execute again, does this
925         // mean we should revert? Always executing LE hopefully should be
926         // faster than performing a sub,cmp,br or even subs,br.
927         LoLoop.Revert = true;
928         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
929       } else {
930         // Record VPR defs and build up their corresponding vpt blocks.
931         // Check we know how to tail predicate any mve instructions.
932         LoLoop.AnalyseMVEInst(&MI);
933       }
934     }
935   }
936 
937   LLVM_DEBUG(LoLoop.dump());
938   if (!LoLoop.FoundAllComponents()) {
939     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
940     return false;
941   }
942 
943   // Check that the only instruction using LoopDec is LoopEnd.
944   // TODO: Check for copy chains that really have no effect.
945   SmallPtrSet<MachineInstr*, 2> Uses;
946   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
947   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
948     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
949     LoLoop.Revert = true;
950   }
951   LoLoop.CheckLegality(BBUtils.get());
952   Expand(LoLoop);
953   return true;
954 }
955 
956 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
957 // beq that branches to the exit branch.
958 // beq that branches to the exit block.
959 // another low register.
960 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
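// A sketch of the reverted sequence, using the operands that this function
// copies from the pseudo (not an exact MIR dump):
//
//   t2CMPri $lr, 0, 14 /* always */, $noreg
//   t2Bcc %exit, 0 /* eq */, $cpsr     (tBcc when the target is in range)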
961   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
962   MachineBasicBlock *MBB = MI->getParent();
963   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
964                                     TII->get(ARM::t2CMPri));
965   MIB.add(MI->getOperand(0));
966   MIB.addImm(0);
967   MIB.addImm(ARMCC::AL);
968   MIB.addReg(ARM::NoRegister);
969 
970   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
971   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
972     ARM::tBcc : ARM::t2Bcc;
973 
974   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
975   MIB.add(MI->getOperand(1));   // branch target
976   MIB.addImm(ARMCC::EQ);        // condition code
977   MIB.addReg(ARM::CPSR);
978   MI->eraseFromParent();
979 }
980 
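// Revert the t2LoopDec pseudo to a plain Thumb-2 subtract, e.g. (a sketch)
// '$lr = t2SUBri $lr, 1', using the flag-setting form when nothing defines
// CPSR between the LoopDec and the LoopEnd. The return value tells the caller
// whether the flags were set, so RevertLoopEnd can then skip its compare.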
981 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
982   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
983   MachineBasicBlock *MBB = MI->getParent();
984   SmallPtrSet<MachineInstr*, 1> Ignore;
985   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
986     if (I->getOpcode() == ARM::t2LoopEnd) {
987       Ignore.insert(&*I);
988       break;
989     }
990   }
991 
992   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
993   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
994 
995   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
996                                     TII->get(ARM::t2SUBri));
997   MIB.addDef(ARM::LR);
998   MIB.add(MI->getOperand(1));
999   MIB.add(MI->getOperand(2));
1000   MIB.addImm(ARMCC::AL);
1001   MIB.addReg(0);
1002 
1003   if (SetFlags) {
1004     MIB.addReg(ARM::CPSR);
1005     MIB->getOperand(5).setIsDef(true);
1006   } else
1007     MIB.addReg(0);
1008 
1009   MI->eraseFromParent();
1010   return SetFlags;
1011 }
1012 
1013 // Generate a subs, or sub and cmp, and a branch instead of an LE.
1014 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1015   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1016 
1017   MachineBasicBlock *MBB = MI->getParent();
1018   // Create cmp
1019   if (!SkipCmp) {
1020     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1021                                       TII->get(ARM::t2CMPri));
1022     MIB.addReg(ARM::LR);
1023     MIB.addImm(0);
1024     MIB.addImm(ARMCC::AL);
1025     MIB.addReg(ARM::NoRegister);
1026   }
1027 
1028   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1029   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1030     ARM::tBcc : ARM::t2Bcc;
1031 
1032   // Create bne
1033   MachineInstrBuilder MIB =
1034     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1035   MIB.add(MI->getOperand(1));   // branch target
1036   MIB.addImm(ARMCC::NE);        // condition code
1037   MIB.addReg(ARM::CPSR);
1038   MI->eraseFromParent();
1039 }
1040 
1041 // Perform dead code elimination on the loop iteration count setup expression.
1042 // If we are tail-predicating, the number of elements to be processed is the
1043 // operand of the VCTP instruction in the vector body, see getCount(), which is
1044 // register $r3 in this example:
1045 //
1046 //   $lr = big-itercount-expression
1047 //   ..
1048 //   t2DoLoopStart renamable $lr
1049 //   vector.body:
1050 //     ..
1051 //     $vpr = MVE_VCTP32 renamable $r3
1052 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1053 //     t2LoopEnd renamable $lr, %vector.body
1054 //     tB %end
1055 //
1056 // What we would like to achieve here is to replace the do-loop start pseudo
1057 // instruction t2DoLoopStart with:
1058 //
1059 //    $lr = MVE_DLSTP_32 killed renamable $r3
1060 //
1061 // Thus, $r3, which defines the number of elements, is written to $lr,
1062 // and then we want to delete the whole chain that used to define $lr;
1063 // see the comment below for what this chain could look like.
1064 //
1065 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1066   if (!LoLoop.IsTailPredicationLegal())
1067     return;
1068 
1069   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1070 
1071   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1072   if (!Def) {
1073     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1074     return;
1075   }
1076 
1077   // Collect and remove the users of iteration count.
1078   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1079                                             LoLoop.End, LoLoop.InsertPt };
1080   SmallPtrSet<MachineInstr*, 2> Remove;
1081   if (RDA->isSafeToRemove(Def, Remove, Killed))
1082     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1083   else {
1084     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1085     return;
1086   }
1087 
1088   // Collect the dead code and the MBBs in which they reside.
1089   RDA->collectKilledOperands(Def, Killed);
1090   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1091   for (auto *MI : Killed)
1092     BasicBlocks.insert(MI->getParent());
1093 
1094   // Collect IT blocks in all affected basic blocks.
1095   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1096   for (auto *MBB : BasicBlocks) {
1097     for (auto &MI : *MBB) {
1098       if (MI.getOpcode() != ARM::t2IT)
1099         continue;
1100       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1101     }
1102   }
1103 
1104   // If we're removing all of the instructions within an IT block, then
1105   // also remove the IT instruction.
1106   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1107   for (auto *MI : Killed) {
1108     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1109       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1110       auto &CurrentBlock = ITBlocks[IT];
1111       CurrentBlock.erase(MI);
1112       if (CurrentBlock.empty())
1113         ModifiedITs.erase(IT);
1114       else
1115         ModifiedITs.insert(IT);
1116     }
1117   }
1118 
1119   // Delete the killed instructions only if we don't have any IT blocks that
1120   // need to be modified because we need to fixup the mask.
1121   // TODO: Handle cases where IT blocks are modified.
1122   if (ModifiedITs.empty()) {
1123     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1124                for (auto *MI : Killed)
1125                  dbgs() << " - " << *MI);
1126     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1127   } else
1128     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1129 }
1130 
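// Expand the LoopStart pseudo into a real loop start instruction. As a rough
// sketch, a non-tail-predicated while loop becomes '$lr = t2WLS $rN, %exit'
// (keeping the pseudo's block operand), while the tail-predicated do-loop
// case is the MVE_DLSTP_32 example shown above IterationCountDCE.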
1131 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1132   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1133   // When using tail-predication, try to delete the dead code that was used to
1134   // calculate the number of loop iterations.
1135   IterationCountDCE(LoLoop);
1136 
1137   MachineInstr *InsertPt = LoLoop.InsertPt;
1138   MachineInstr *Start = LoLoop.Start;
1139   MachineBasicBlock *MBB = InsertPt->getParent();
1140   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1141   unsigned Opc = LoLoop.getStartOpcode();
1142   MachineOperand &Count = LoLoop.getCount();
1143 
1144   MachineInstrBuilder MIB =
1145     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1146 
1147   MIB.addDef(ARM::LR);
1148   MIB.add(Count);
1149   if (!IsDo)
1150     MIB.add(Start->getOperand(1));
1151 
1152   // If we're inserting at a mov lr, then remove it as it's redundant.
1153   if (InsertPt != Start)
1154     LoLoop.ToRemove.insert(InsertPt);
1155   LoLoop.ToRemove.insert(Start);
1156   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1157   return &*MIB;
1158 }
1159 
1160 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1161   auto RemovePredicate = [](MachineInstr *MI) {
1162     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1163     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1164       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1165              "Expected Then predicate!");
1166       MI->getOperand(PIdx).setImm(ARMVCC::None);
1167       MI->getOperand(PIdx+1).setReg(0);
1168     } else
1169       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1170   };
1171 
1172   // There are a few scenarios which we have to fix up:
1173   // 1) A VPT block which is only predicated by the vctp and has no internal vpr
1174   //    defs.
1175   // 2) A VPT block which is only predicated by the vctp but has an internal
1176   //    vpr def.
1177   // 3) A VPT block which is predicated upon the vctp as well as another vpr
1178   //    def.
1179   // 4) A VPT block which is not predicated upon a vctp, but contains it and
1180   //    all instructions within the block are predicated upon it.
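  //
  // As a sketch of scenario 1, under the implicit predication of the LETP:
  //
  //   vpstt
  //   vldrwt.u32 q0, [r0]
  //   vaddt.i32  q1, q1, q0
  //
  // has its VPST removed and the two instructions unpredicated, becoming a
  // plain vldrw.u32 and vadd.i32.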
1181 
1182   for (auto &Block : LoLoop.getVPTBlocks()) {
1183     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1184     if (Block.HasNonUniformPredicate()) {
1185       PredicatedMI *Divergent = Block.getDivergent();
1186       if (isVCTP(Divergent->MI)) {
1187         // The vctp will be removed, so the block mask of the VPST/VPT will need
1188         // to be recomputed.
1189         LoLoop.BlockMasksToRecompute.insert(Block.getVPST());
1190       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1191         // The VPT block has a non-uniform predicate but its entry is guarded
1192         // only by a vctp, which means we:
1193         // - Need to remove the original vpst.
1194         // - Then need to unpredicate any following instructions, until
1195         //   we come across the divergent vpr def.
1196         // - Insert a new vpst to predicate the instruction(s) that follow
1197         //   the divergent vpr def.
1198         // TODO: We could be producing more VPT blocks than necessary and could
1199         // fold the newly created one into a preceding one.
1200         for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
1201              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1202           RemovePredicate(&*I);
1203 
1204         unsigned Size = 0;
1205         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1206         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1207         MachineInstr *InsertAt = nullptr;
1208         while (I != E) {
1209           InsertAt = &*I;
1210           ++Size;
1211           ++I;
1212         }
1213         // Create a VPST with a null mask, we'll recompute it later.
1214         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1215                                           InsertAt->getDebugLoc(),
1216                                           TII->get(ARM::MVE_VPST));
1217         MIB.addImm(0);
1218         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1219         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1220         LoLoop.ToRemove.insert(Block.getVPST());
1221         LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
1222       }
1223     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1224       // A vpt block which is only predicated upon vctp and has no internal vpr
1225       // defs:
1226       // - Remove vpst.
1227       // - Unpredicate the remaining instructions.
1228       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1229       LoLoop.ToRemove.insert(Block.getVPST());
1230       for (auto &PredMI : Insts)
1231         RemovePredicate(PredMI.MI);
1232     }
1233   }
1234   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
1235   LoLoop.ToRemove.insert(LoLoop.VCTP);
1236 }
1237 
1238 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1239 
1240   // Combine the LoopDec and LoopEnd instructions into LE(TP).
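  // e.g. (a sketch of the result): '$lr = t2LEUpdate $lr, %loop', or
  // '$lr = MVE_LETP $lr, %loop' when tail-predicating.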
1241   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1242     MachineInstr *End = LoLoop.End;
1243     MachineBasicBlock *MBB = End->getParent();
1244     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1245       ARM::MVE_LETP : ARM::t2LEUpdate;
1246     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1247                                       TII->get(Opc));
1248     MIB.addDef(ARM::LR);
1249     MIB.add(End->getOperand(0));
1250     MIB.add(End->getOperand(1));
1251     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1252     LoLoop.ToRemove.insert(LoLoop.Dec);
1253     LoLoop.ToRemove.insert(End);
1254     return &*MIB;
1255   };
1256 
1257   // TODO: We should be able to automatically remove these branches before we
1258   // get here - probably by teaching analyzeBranch about the pseudo
1259   // instructions.
1260   // If there is an unconditional branch, after I, that just branches to the
1261   // next block, remove it.
1262   auto RemoveDeadBranch = [](MachineInstr *I) {
1263     MachineBasicBlock *BB = I->getParent();
1264     MachineInstr *Terminator = &BB->instr_back();
1265     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1266       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1267       if (BB->isLayoutSuccessor(Succ)) {
1268         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1269         Terminator->eraseFromParent();
1270       }
1271     }
1272   };
1273 
1274   if (LoLoop.Revert) {
1275     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1276       RevertWhile(LoLoop.Start);
1277     else
1278       LoLoop.Start->eraseFromParent();
1279     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
1280     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1281   } else {
1282     LoLoop.Start = ExpandLoopStart(LoLoop);
1283     RemoveDeadBranch(LoLoop.Start);
1284     LoLoop.End = ExpandLoopEnd(LoLoop);
1285     RemoveDeadBranch(LoLoop.End);
1286     if (LoLoop.IsTailPredicationLegal())
1287       ConvertVPTBlocks(LoLoop);
1288     for (auto *I : LoLoop.ToRemove) {
1289       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1290       I->eraseFromParent();
1291     }
1292     for (auto *I : LoLoop.BlockMasksToRecompute) {
1293       LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
1294       recomputeVPTBlockMask(*I);
1295       LLVM_DEBUG(dbgs() << "           ... done: " << *I);
1296     }
1297   }
1298 
1299   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1300   DFS.ProcessLoop();
1301   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1302   for (auto *MBB : PostOrder) {
1303     recomputeLiveIns(*MBB);
1304     // FIXME: For some reason, the live-in print order is non-deterministic for
1305     // our tests and I can't figure out why... So just sort them.
1306     MBB->sortUniqueLiveIns();
1307   }
1308 
1309   for (auto *MBB : reverse(PostOrder))
1310     recomputeLivenessFlags(*MBB);
1311 
1312   // We've moved, removed and inserted new instructions, so update RDA.
1313   RDA->reset();
1314 }
1315 
1316 bool ARMLowOverheadLoops::RevertNonLoops() {
1317   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1318   bool Changed = false;
1319 
1320   for (auto &MBB : *MF) {
1321     SmallVector<MachineInstr*, 4> Starts;
1322     SmallVector<MachineInstr*, 4> Decs;
1323     SmallVector<MachineInstr*, 4> Ends;
1324 
1325     for (auto &I : MBB) {
1326       if (isLoopStart(I))
1327         Starts.push_back(&I);
1328       else if (I.getOpcode() == ARM::t2LoopDec)
1329         Decs.push_back(&I);
1330       else if (I.getOpcode() == ARM::t2LoopEnd)
1331         Ends.push_back(&I);
1332     }
1333 
1334     if (Starts.empty() && Decs.empty() && Ends.empty())
1335       continue;
1336 
1337     Changed = true;
1338 
1339     for (auto *Start : Starts) {
1340       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1341         RevertWhile(Start);
1342       else
1343         Start->eraseFromParent();
1344     }
1345     for (auto *Dec : Decs)
1346       RevertLoopDec(Dec);
1347 
1348     for (auto *End : Ends)
1349       RevertLoopEnd(End);
1350   }
1351   return Changed;
1352 }
1353 
1354 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1355   return new ARMLowOverheadLoops();
1356 }
1357