xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/ARM/MVETailPredication.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
15ffd83dbSDimitry Andric //===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
28bcb0991SDimitry Andric //
38bcb0991SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48bcb0991SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
58bcb0991SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68bcb0991SDimitry Andric //
78bcb0991SDimitry Andric //===----------------------------------------------------------------------===//
88bcb0991SDimitry Andric //
98bcb0991SDimitry Andric /// \file
108bcb0991SDimitry Andric /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
115ffd83dbSDimitry Andric /// branches to help accelerate DSP applications. These two extensions,
125ffd83dbSDimitry Andric /// combined with a new form of predication called tail-predication, can be used
135ffd83dbSDimitry Andric /// to provide implicit vector predication within a low-overhead loop.
145ffd83dbSDimitry Andric /// This is implicit because the predicate of active/inactive lanes is
155ffd83dbSDimitry Andric /// calculated by hardware, and thus does not need to be explicitly passed
165ffd83dbSDimitry Andric /// to vector instructions. The instructions responsible for this are the
/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and the
/// total number of data elements processed by the loop. The loop-end
195ffd83dbSDimitry Andric /// LETP instruction is responsible for decrementing and setting the remaining
205ffd83dbSDimitry Andric /// elements to be processed and generating the mask of active lanes.
215ffd83dbSDimitry Andric ///
228bcb0991SDimitry Andric /// The HardwareLoops pass inserts intrinsics identifying loops that the
238bcb0991SDimitry Andric /// backend will attempt to convert into a low-overhead loop. The vectorizer is
248bcb0991SDimitry Andric /// responsible for generating a vectorized loop in which the lanes are
/// predicated upon a get.active.lane.mask intrinsic. This pass looks at these
/// get.active.lane.mask intrinsics and attempts to convert them to VCTP
27e8d8bef9SDimitry Andric /// instructions. This will be picked up by the ARM Low-overhead loop pass later
28e8d8bef9SDimitry Andric /// in the backend, which performs the final transformation to a DLSTP or WLSTP
29e8d8bef9SDimitry Andric /// tail-predicated loop.
30e8d8bef9SDimitry Andric //
31e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
328bcb0991SDimitry Andric 
33480093f4SDimitry Andric #include "ARM.h"
34480093f4SDimitry Andric #include "ARMSubtarget.h"
355ffd83dbSDimitry Andric #include "ARMTargetTransformInfo.h"
368bcb0991SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
378bcb0991SDimitry Andric #include "llvm/Analysis/LoopPass.h"
388bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
398bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h"
405ffd83dbSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h"
418bcb0991SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
4206c3fb27SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
438bcb0991SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
448bcb0991SDimitry Andric #include "llvm/IR/IRBuilder.h"
45480093f4SDimitry Andric #include "llvm/IR/Instructions.h"
46480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h"
478bcb0991SDimitry Andric #include "llvm/IR/PatternMatch.h"
485ffd83dbSDimitry Andric #include "llvm/InitializePasses.h"
498bcb0991SDimitry Andric #include "llvm/Support/Debug.h"
508bcb0991SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
51e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/Local.h"
525ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h"
535ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
548bcb0991SDimitry Andric 
558bcb0991SDimitry Andric using namespace llvm;
568bcb0991SDimitry Andric 
578bcb0991SDimitry Andric #define DEBUG_TYPE "mve-tail-predication"
588bcb0991SDimitry Andric #define DESC "Transform predicated vector loops to use MVE tail predication"
598bcb0991SDimitry Andric 
605ffd83dbSDimitry Andric cl::opt<TailPredication::Mode> EnableTailPredication(
61e8d8bef9SDimitry Andric    "tail-predication", cl::desc("MVE tail-predication pass options"),
62e8d8bef9SDimitry Andric    cl::init(TailPredication::Enabled),
635ffd83dbSDimitry Andric    cl::values(clEnumValN(TailPredication::Disabled, "disabled",
645ffd83dbSDimitry Andric                          "Don't tail-predicate loops"),
655ffd83dbSDimitry Andric               clEnumValN(TailPredication::EnabledNoReductions,
665ffd83dbSDimitry Andric                          "enabled-no-reductions",
675ffd83dbSDimitry Andric                          "Enable tail-predication, but not for reduction loops"),
685ffd83dbSDimitry Andric               clEnumValN(TailPredication::Enabled,
695ffd83dbSDimitry Andric                          "enabled",
705ffd83dbSDimitry Andric                          "Enable tail-predication, including reduction loops"),
715ffd83dbSDimitry Andric               clEnumValN(TailPredication::ForceEnabledNoReductions,
725ffd83dbSDimitry Andric                          "force-enabled-no-reductions",
735ffd83dbSDimitry Andric                          "Enable tail-predication, but not for reduction loops, "
745ffd83dbSDimitry Andric                          "and force this which might be unsafe"),
755ffd83dbSDimitry Andric               clEnumValN(TailPredication::ForceEnabled,
765ffd83dbSDimitry Andric                          "force-enabled",
775ffd83dbSDimitry Andric                          "Enable tail-predication, including reduction loops, "
785ffd83dbSDimitry Andric                          "and force this which might be unsafe")));
795ffd83dbSDimitry Andric 
805ffd83dbSDimitry Andric 
818bcb0991SDimitry Andric namespace {
828bcb0991SDimitry Andric 
838bcb0991SDimitry Andric class MVETailPredication : public LoopPass {
848bcb0991SDimitry Andric   SmallVector<IntrinsicInst*, 4> MaskedInsts;
858bcb0991SDimitry Andric   Loop *L = nullptr;
868bcb0991SDimitry Andric   ScalarEvolution *SE = nullptr;
878bcb0991SDimitry Andric   TargetTransformInfo *TTI = nullptr;
885ffd83dbSDimitry Andric   const ARMSubtarget *ST = nullptr;
898bcb0991SDimitry Andric 
908bcb0991SDimitry Andric public:
918bcb0991SDimitry Andric   static char ID;
928bcb0991SDimitry Andric 
938bcb0991SDimitry Andric   MVETailPredication() : LoopPass(ID) { }
948bcb0991SDimitry Andric 
958bcb0991SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
968bcb0991SDimitry Andric     AU.addRequired<ScalarEvolutionWrapperPass>();
978bcb0991SDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
988bcb0991SDimitry Andric     AU.addRequired<TargetPassConfig>();
998bcb0991SDimitry Andric     AU.addRequired<TargetTransformInfoWrapperPass>();
1008bcb0991SDimitry Andric     AU.addPreserved<LoopInfoWrapperPass>();
1018bcb0991SDimitry Andric     AU.setPreservesCFG();
1028bcb0991SDimitry Andric   }
1038bcb0991SDimitry Andric 
1048bcb0991SDimitry Andric   bool runOnLoop(Loop *L, LPPassManager&) override;
1058bcb0991SDimitry Andric 
1068bcb0991SDimitry Andric private:
107e8d8bef9SDimitry Andric   /// Perform the relevant checks on the loop and convert active lane masks if
108e8d8bef9SDimitry Andric   /// possible.
109e8d8bef9SDimitry Andric   bool TryConvertActiveLaneMask(Value *TripCount);
1108bcb0991SDimitry Andric 
111e8d8bef9SDimitry Andric   /// Perform several checks on the arguments of @llvm.get.active.lane.mask
112e8d8bef9SDimitry Andric   /// intrinsic. E.g., check that the loop induction variable and the element
113e8d8bef9SDimitry Andric   /// count are of the form we expect, and also perform overflow checks for
114e8d8bef9SDimitry Andric   /// the new expressions that are created.
11506c3fb27SDimitry Andric   const SCEV *IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount);
116480093f4SDimitry Andric 
117480093f4SDimitry Andric   /// Insert the intrinsic to represent the effect of tail predication.
11806c3fb27SDimitry Andric   void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *Start);
1198bcb0991SDimitry Andric };
1208bcb0991SDimitry Andric 
1218bcb0991SDimitry Andric } // end namespace
1228bcb0991SDimitry Andric 
1238bcb0991SDimitry Andric bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
1245ffd83dbSDimitry Andric   if (skipLoop(L) || !EnableTailPredication)
1258bcb0991SDimitry Andric     return false;
1268bcb0991SDimitry Andric 
1275ffd83dbSDimitry Andric   MaskedInsts.clear();
1288bcb0991SDimitry Andric   Function &F = *L->getHeader()->getParent();
1298bcb0991SDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
1308bcb0991SDimitry Andric   auto &TM = TPC.getTM<TargetMachine>();
1315ffd83dbSDimitry Andric   ST = &TM.getSubtarget<ARMSubtarget>(F);
1328bcb0991SDimitry Andric   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1338bcb0991SDimitry Andric   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1348bcb0991SDimitry Andric   this->L = L;
1358bcb0991SDimitry Andric 
1368bcb0991SDimitry Andric   // The MVE and LOB extensions are combined to enable tail-predication, but
1378bcb0991SDimitry Andric   // there's nothing preventing us from generating VCTP instructions for v8.1m.
1388bcb0991SDimitry Andric   if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
139480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
1408bcb0991SDimitry Andric     return false;
1418bcb0991SDimitry Andric   }
1428bcb0991SDimitry Andric 
1438bcb0991SDimitry Andric   BasicBlock *Preheader = L->getLoopPreheader();
1448bcb0991SDimitry Andric   if (!Preheader)
1458bcb0991SDimitry Andric     return false;
1468bcb0991SDimitry Andric 
1478bcb0991SDimitry Andric   auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
1488bcb0991SDimitry Andric     for (auto &I : *BB) {
1498bcb0991SDimitry Andric       auto *Call = dyn_cast<IntrinsicInst>(&I);
1508bcb0991SDimitry Andric       if (!Call)
1518bcb0991SDimitry Andric         continue;
1528bcb0991SDimitry Andric 
1538bcb0991SDimitry Andric       Intrinsic::ID ID = Call->getIntrinsicID();
154e8d8bef9SDimitry Andric       if (ID == Intrinsic::start_loop_iterations ||
155fe6060f1SDimitry Andric           ID == Intrinsic::test_start_loop_iterations)
1568bcb0991SDimitry Andric         return cast<IntrinsicInst>(&I);
1578bcb0991SDimitry Andric     }
1588bcb0991SDimitry Andric     return nullptr;
1598bcb0991SDimitry Andric   };
1608bcb0991SDimitry Andric 
1618bcb0991SDimitry Andric   // Look for the hardware loop intrinsic that sets the iteration count.
1628bcb0991SDimitry Andric   IntrinsicInst *Setup = FindLoopIterations(Preheader);
1638bcb0991SDimitry Andric 
1648bcb0991SDimitry Andric   // The test.set iteration could live in the pre-preheader.
1658bcb0991SDimitry Andric   if (!Setup) {
1668bcb0991SDimitry Andric     if (!Preheader->getSinglePredecessor())
1678bcb0991SDimitry Andric       return false;
1688bcb0991SDimitry Andric     Setup = FindLoopIterations(Preheader->getSinglePredecessor());
1698bcb0991SDimitry Andric     if (!Setup)
1708bcb0991SDimitry Andric       return false;
1718bcb0991SDimitry Andric   }
1728bcb0991SDimitry Andric 
173e8d8bef9SDimitry Andric   LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n");
1748bcb0991SDimitry Andric 
175e8d8bef9SDimitry Andric   bool Changed = TryConvertActiveLaneMask(Setup->getArgOperand(0));
1768bcb0991SDimitry Andric 
177e8d8bef9SDimitry Andric   return Changed;
1788bcb0991SDimitry Andric }
1798bcb0991SDimitry Andric 
1805ffd83dbSDimitry Andric // The active lane intrinsic has this form:
1815ffd83dbSDimitry Andric //
182e8d8bef9SDimitry Andric //    @llvm.get.active.lane.mask(IV, TC)
1835ffd83dbSDimitry Andric //
1845ffd83dbSDimitry Andric // Here we perform checks that this intrinsic behaves as expected,
1855ffd83dbSDimitry Andric // which means:
1865ffd83dbSDimitry Andric //
187e8d8bef9SDimitry Andric // 1) Check that the TripCount (TC) belongs to this loop (originally).
188e8d8bef9SDimitry Andric // 2) The element count (TC) needs to be sufficiently large that the decrement
189e8d8bef9SDimitry Andric //    of element counter doesn't overflow, which means that we need to prove:
1905ffd83dbSDimitry Andric //        ceil(ElementCount / VectorWidth) >= TripCount
1915ffd83dbSDimitry Andric //    by rounding up ElementCount up:
1925ffd83dbSDimitry Andric //        ((ElementCount + (VectorWidth - 1)) / VectorWidth
1935ffd83dbSDimitry Andric //    and evaluate if expression isKnownNonNegative:
1945ffd83dbSDimitry Andric //        (((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
1955ffd83dbSDimitry Andric // 3) The IV must be an induction phi with an increment equal to the
1965ffd83dbSDimitry Andric //    vector width.
19706c3fb27SDimitry Andric const SCEV *MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
198e8d8bef9SDimitry Andric                                                  Value *TripCount) {
1995ffd83dbSDimitry Andric   bool ForceTailPredication =
2005ffd83dbSDimitry Andric     EnableTailPredication == TailPredication::ForceEnabledNoReductions ||
2015ffd83dbSDimitry Andric     EnableTailPredication == TailPredication::ForceEnabled;
2025ffd83dbSDimitry Andric 
203e8d8bef9SDimitry Andric   Value *ElemCount = ActiveLaneMask->getOperand(1);
2044652422eSDimitry Andric   bool Changed = false;
2054652422eSDimitry Andric   if (!L->makeLoopInvariant(ElemCount, Changed))
20606c3fb27SDimitry Andric     return nullptr;
2074652422eSDimitry Andric 
208e8d8bef9SDimitry Andric   auto *EC= SE->getSCEV(ElemCount);
2095ffd83dbSDimitry Andric   auto *TC = SE->getSCEV(TripCount);
210e8d8bef9SDimitry Andric   int VectorWidth =
211e8d8bef9SDimitry Andric       cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
2120eae32dcSDimitry Andric   if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
2130eae32dcSDimitry Andric       VectorWidth != 16)
21406c3fb27SDimitry Andric     return nullptr;
215e8d8bef9SDimitry Andric   ConstantInt *ConstElemCount = nullptr;
2165ffd83dbSDimitry Andric 
217e8d8bef9SDimitry Andric   // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to
218e8d8bef9SDimitry Andric   // this loop.  The scalar tripcount corresponds the number of elements
219e8d8bef9SDimitry Andric   // processed by the loop, so we will refer to that from this point on.
220e8d8bef9SDimitry Andric   if (!SE->isLoopInvariant(EC, L)) {
221e8d8bef9SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n");
22206c3fb27SDimitry Andric     return nullptr;
22306c3fb27SDimitry Andric   }
22406c3fb27SDimitry Andric 
22506c3fb27SDimitry Andric   // 2) Find out if IV is an induction phi. Note that we can't use Loop
22606c3fb27SDimitry Andric   // helpers here to get the induction variable, because the hardware loop is
22706c3fb27SDimitry Andric   // no longer in loopsimplify form, and also the hwloop intrinsic uses a
22806c3fb27SDimitry Andric   // different counter. Using SCEV, we check that the induction is of the
22906c3fb27SDimitry Andric   // form i = i + 4, where the increment must be equal to the VectorWidth.
23006c3fb27SDimitry Andric   auto *IV = ActiveLaneMask->getOperand(0);
23106c3fb27SDimitry Andric   auto *IVExpr = SE->getSCEV(IV);
23206c3fb27SDimitry Andric   auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
23306c3fb27SDimitry Andric 
23406c3fb27SDimitry Andric   if (!AddExpr) {
23506c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump());
23606c3fb27SDimitry Andric     return nullptr;
23706c3fb27SDimitry Andric   }
23806c3fb27SDimitry Andric   // Check that this AddRec is associated with this loop.
23906c3fb27SDimitry Andric   if (AddExpr->getLoop() != L) {
24006c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n");
24106c3fb27SDimitry Andric     return nullptr;
24206c3fb27SDimitry Andric   }
24306c3fb27SDimitry Andric   auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
24406c3fb27SDimitry Andric   if (!Step) {
24506c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: ";
24606c3fb27SDimitry Andric                AddExpr->getOperand(1)->dump());
24706c3fb27SDimitry Andric     return nullptr;
24806c3fb27SDimitry Andric   }
24906c3fb27SDimitry Andric   auto StepValue = Step->getValue()->getSExtValue();
25006c3fb27SDimitry Andric   if (VectorWidth != StepValue) {
25106c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue
25206c3fb27SDimitry Andric                       << " doesn't match vector width " << VectorWidth << "\n");
25306c3fb27SDimitry Andric     return nullptr;
2545ffd83dbSDimitry Andric   }
2555ffd83dbSDimitry Andric 
256e8d8bef9SDimitry Andric   if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
257e8d8bef9SDimitry Andric     ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
258e8d8bef9SDimitry Andric     if (!TC) {
259e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in "
260e8d8bef9SDimitry Andric                            "set.loop.iterations\n");
26106c3fb27SDimitry Andric       return nullptr;
262e8d8bef9SDimitry Andric     }
263e8d8bef9SDimitry Andric 
264e8d8bef9SDimitry Andric     // Calculate 2 tripcount values and check that they are consistent with
265e8d8bef9SDimitry Andric     // each other. The TripCount for a predicated vector loop body is
266e8d8bef9SDimitry Andric     // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we
267e8d8bef9SDimitry Andric     // work it out here.
268e8d8bef9SDimitry Andric     uint64_t TC1 = TC->getZExtValue();
269e8d8bef9SDimitry Andric     uint64_t TC2 =
270e8d8bef9SDimitry Andric         (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth;
271e8d8bef9SDimitry Andric 
272e8d8bef9SDimitry Andric     // If the tripcount values are inconsistent, we can't insert the VCTP and
273e8d8bef9SDimitry Andric     // trigger tail-predication; keep the intrinsic as a get.active.lane.mask
274e8d8bef9SDimitry Andric     // and legalize this.
275e8d8bef9SDimitry Andric     if (TC1 != TC2) {
276e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: "
277e8d8bef9SDimitry Andric                  << TC1 << " from set.loop.iterations, and "
278e8d8bef9SDimitry Andric                  << TC2 << " from get.active.lane.mask\n");
27906c3fb27SDimitry Andric       return nullptr;
280e8d8bef9SDimitry Andric     }
281e8d8bef9SDimitry Andric   } else if (!ForceTailPredication) {
28206c3fb27SDimitry Andric     // 3) We need to prove that the sub expression that we create in the
283e8d8bef9SDimitry Andric     // tail-predicated loop body, which calculates the remaining elements to be
284e8d8bef9SDimitry Andric     // processed, is non-negative, i.e. it doesn't overflow:
2855ffd83dbSDimitry Andric     //
286e8d8bef9SDimitry Andric     //   ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0
2875ffd83dbSDimitry Andric     //
288e8d8bef9SDimitry Andric     // This is true if:
2895ffd83dbSDimitry Andric     //
290e8d8bef9SDimitry Andric     //    TripCount == (ElementCount + VectorWidth - 1) / VectorWidth
2915ffd83dbSDimitry Andric     //
292e8d8bef9SDimitry Andric     // which what we will be using here.
2935ffd83dbSDimitry Andric     //
294e8d8bef9SDimitry Andric     auto *VW = SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth));
295e8d8bef9SDimitry Andric     // ElementCount + (VW-1):
29606c3fb27SDimitry Andric     auto *Start = AddExpr->getStart();
297e8d8bef9SDimitry Andric     auto *ECPlusVWMinus1 = SE->getAddExpr(EC,
2985ffd83dbSDimitry Andric         SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1)));
2995ffd83dbSDimitry Andric 
300e8d8bef9SDimitry Andric     // Ceil = ElementCount + (VW-1) / VW
301e8d8bef9SDimitry Andric     auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);
302e8d8bef9SDimitry Andric 
303e8d8bef9SDimitry Andric     // Prevent unused variable warnings with TC
304e8d8bef9SDimitry Andric     (void)TC;
30506c3fb27SDimitry Andric     LLVM_DEBUG({
306e8d8bef9SDimitry Andric       dbgs() << "ARM TP: Analysing overflow behaviour for:\n";
30706c3fb27SDimitry Andric       dbgs() << "ARM TP: - TripCount = " << *TC << "\n";
30806c3fb27SDimitry Andric       dbgs() << "ARM TP: - ElemCount = " << *EC << "\n";
30906c3fb27SDimitry Andric       dbgs() << "ARM TP: - Start = " << *Start << "\n";
31006c3fb27SDimitry Andric       dbgs() << "ARM TP: - BETC = " << *SE->getBackedgeTakenCount(L) << "\n";
311e8d8bef9SDimitry Andric       dbgs() << "ARM TP: - VecWidth =  " << VectorWidth << "\n";
31206c3fb27SDimitry Andric       dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil << "\n";
31306c3fb27SDimitry Andric     });
314e8d8bef9SDimitry Andric 
315e8d8bef9SDimitry Andric     // As an example, almost all the tripcount expressions (produced by the
316e8d8bef9SDimitry Andric     // vectoriser) look like this:
317e8d8bef9SDimitry Andric     //
31806c3fb27SDimitry Andric     //   TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw> - start) /u 4)
319e8d8bef9SDimitry Andric     //
320e8d8bef9SDimitry Andric     // and "ElementCount + (VW-1) / VW":
321e8d8bef9SDimitry Andric     //
322e8d8bef9SDimitry Andric     //   Ceil = ((3 + %N) /u 4)
323e8d8bef9SDimitry Andric     //
324e8d8bef9SDimitry Andric     // Check for equality of TC and Ceil by calculating SCEV expression
325e8d8bef9SDimitry Andric     // TC - Ceil and test it for zero.
326e8d8bef9SDimitry Andric     //
32706c3fb27SDimitry Andric     const SCEV *Div = SE->getUDivExpr(
32806c3fb27SDimitry Andric         SE->getAddExpr(SE->getMulExpr(Ceil, VW), SE->getNegativeSCEV(VW),
32906c3fb27SDimitry Andric                        SE->getNegativeSCEV(Start)),
33006c3fb27SDimitry Andric         VW);
33106c3fb27SDimitry Andric     const SCEV *Sub = SE->getMinusSCEV(SE->getBackedgeTakenCount(L), Div);
33206c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: - Sub       = "; Sub->dump());
333e8d8bef9SDimitry Andric 
334349cc55cSDimitry Andric     // Use context sensitive facts about the path to the loop to refine.  This
335349cc55cSDimitry Andric     // comes up as the backedge taken count can incorporate context sensitive
336349cc55cSDimitry Andric     // reasoning, and our RHS just above doesn't.
337349cc55cSDimitry Andric     Sub = SE->applyLoopGuards(Sub, L);
33806c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: - (Guarded) = "; Sub->dump());
339349cc55cSDimitry Andric 
340349cc55cSDimitry Andric     if (!Sub->isZero()) {
341e8d8bef9SDimitry Andric       LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n");
34206c3fb27SDimitry Andric       return nullptr;
3435ffd83dbSDimitry Andric     }
344e8d8bef9SDimitry Andric   }
3455ffd83dbSDimitry Andric 
34606c3fb27SDimitry Andric   // Check that the start value is a multiple of the VectorWidth.
34706c3fb27SDimitry Andric   // TODO: This could do with a method to check if the scev is a multiple of
34806c3fb27SDimitry Andric   // VectorWidth. For the moment we just check for constants, muls and unknowns
34906c3fb27SDimitry Andric   // (which use MaskedValueIsZero and seems to be the most common).
35006c3fb27SDimitry Andric   if (auto *BaseC = dyn_cast<SCEVConstant>(AddExpr->getStart())) {
35106c3fb27SDimitry Andric     if (BaseC->getAPInt().urem(VectorWidth) == 0)
35206c3fb27SDimitry Andric       return SE->getMinusSCEV(EC, BaseC);
35306c3fb27SDimitry Andric   } else if (auto *BaseV = dyn_cast<SCEVUnknown>(AddExpr->getStart())) {
35406c3fb27SDimitry Andric     Type *Ty = BaseV->getType();
35506c3fb27SDimitry Andric     APInt Mask = APInt::getLowBitsSet(Ty->getPrimitiveSizeInBits(),
35606c3fb27SDimitry Andric                                       Log2_64(VectorWidth));
35706c3fb27SDimitry Andric     if (MaskedValueIsZero(BaseV->getValue(), Mask,
358*0fca6ea1SDimitry Andric                           L->getHeader()->getDataLayout()))
35906c3fb27SDimitry Andric       return SE->getMinusSCEV(EC, BaseV);
36006c3fb27SDimitry Andric   } else if (auto *BaseMul = dyn_cast<SCEVMulExpr>(AddExpr->getStart())) {
36106c3fb27SDimitry Andric     if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(0)))
36206c3fb27SDimitry Andric       if (BaseC->getAPInt().urem(VectorWidth) == 0)
36306c3fb27SDimitry Andric         return SE->getMinusSCEV(EC, BaseC);
36406c3fb27SDimitry Andric     if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(1)))
36506c3fb27SDimitry Andric       if (BaseC->getAPInt().urem(VectorWidth) == 0)
36606c3fb27SDimitry Andric         return SE->getMinusSCEV(EC, BaseC);
36706c3fb27SDimitry Andric   }
368e8d8bef9SDimitry Andric 
36906c3fb27SDimitry Andric   LLVM_DEBUG(
37006c3fb27SDimitry Andric       dbgs() << "ARM TP: induction base is not know to be a multiple of VF: "
37106c3fb27SDimitry Andric              << *AddExpr->getOperand(0) << "\n");
37206c3fb27SDimitry Andric   return nullptr;
3735ffd83dbSDimitry Andric }
3745ffd83dbSDimitry Andric 
3755ffd83dbSDimitry Andric void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
37606c3fb27SDimitry Andric                                              Value *Start) {
3775ffd83dbSDimitry Andric   IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
378480093f4SDimitry Andric   Module *M = L->getHeader()->getModule();
379480093f4SDimitry Andric   Type *Ty = IntegerType::get(M->getContext(), 32);
380e8d8bef9SDimitry Andric   unsigned VectorWidth =
381e8d8bef9SDimitry Andric       cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
3828bcb0991SDimitry Andric 
383480093f4SDimitry Andric   // Insert a phi to count the number of elements processed by the loop.
3845f757f3fSDimitry Andric   Builder.SetInsertPoint(L->getHeader(), L->getHeader()->getFirstNonPHIIt());
385480093f4SDimitry Andric   PHINode *Processed = Builder.CreatePHI(Ty, 2);
38606c3fb27SDimitry Andric   Processed->addIncoming(Start, L->getLoopPreheader());
387480093f4SDimitry Andric 
388e8d8bef9SDimitry Andric   // Replace @llvm.get.active.mask() with the ARM specific VCTP intrinic, and
389e8d8bef9SDimitry Andric   // thus represent the effect of tail predication.
3905ffd83dbSDimitry Andric   Builder.SetInsertPoint(ActiveLaneMask);
391e8d8bef9SDimitry Andric   ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);
392480093f4SDimitry Andric 
393480093f4SDimitry Andric   Intrinsic::ID VCTPID;
3945ffd83dbSDimitry Andric   switch (VectorWidth) {
395480093f4SDimitry Andric   default:
396480093f4SDimitry Andric     llvm_unreachable("unexpected number of lanes");
3970eae32dcSDimitry Andric   case 2:  VCTPID = Intrinsic::arm_mve_vctp64; break;
398480093f4SDimitry Andric   case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
399480093f4SDimitry Andric   case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
400480093f4SDimitry Andric   case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
401480093f4SDimitry Andric   }
402480093f4SDimitry Andric   Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
4035ffd83dbSDimitry Andric   Value *VCTPCall = Builder.CreateCall(VCTP, Processed);
4045ffd83dbSDimitry Andric   ActiveLaneMask->replaceAllUsesWith(VCTPCall);
405480093f4SDimitry Andric 
406480093f4SDimitry Andric   // Add the incoming value to the new phi.
407480093f4SDimitry Andric   // TODO: This add likely already exists in the loop.
408480093f4SDimitry Andric   Value *Remaining = Builder.CreateSub(Processed, Factor);
409480093f4SDimitry Andric   Processed->addIncoming(Remaining, L->getLoopLatch());
410480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
411480093f4SDimitry Andric              << *Processed << "\n"
4125ffd83dbSDimitry Andric              << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n");
413480093f4SDimitry Andric }
414480093f4SDimitry Andric 
415e8d8bef9SDimitry Andric bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) {
416e8d8bef9SDimitry Andric   SmallVector<IntrinsicInst *, 4> ActiveLaneMasks;
417e8d8bef9SDimitry Andric   for (auto *BB : L->getBlocks())
418e8d8bef9SDimitry Andric     for (auto &I : *BB)
419e8d8bef9SDimitry Andric       if (auto *Int = dyn_cast<IntrinsicInst>(&I))
420e8d8bef9SDimitry Andric         if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
421e8d8bef9SDimitry Andric           ActiveLaneMasks.push_back(Int);
422e8d8bef9SDimitry Andric 
423e8d8bef9SDimitry Andric   if (ActiveLaneMasks.empty())
424480093f4SDimitry Andric     return false;
425480093f4SDimitry Andric 
426480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
4278bcb0991SDimitry Andric 
428e8d8bef9SDimitry Andric   for (auto *ActiveLaneMask : ActiveLaneMasks) {
4295ffd83dbSDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: "
4305ffd83dbSDimitry Andric                       << *ActiveLaneMask << "\n");
4318bcb0991SDimitry Andric 
43206c3fb27SDimitry Andric     const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount);
43306c3fb27SDimitry Andric     if (!StartSCEV) {
4345ffd83dbSDimitry Andric       LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n");
4355ffd83dbSDimitry Andric       return false;
4365ffd83dbSDimitry Andric     }
43706c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP. Start is " << *StartSCEV
43806c3fb27SDimitry Andric                       << "\n");
439*0fca6ea1SDimitry Andric     SCEVExpander Expander(*SE, L->getHeader()->getDataLayout(),
44006c3fb27SDimitry Andric                           "start");
44106c3fb27SDimitry Andric     Instruction *Ins = L->getLoopPreheader()->getTerminator();
44206c3fb27SDimitry Andric     Value *Start = Expander.expandCodeFor(StartSCEV, StartSCEV->getType(), Ins);
44306c3fb27SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: Created start value " << *Start << "\n");
44406c3fb27SDimitry Andric     InsertVCTPIntrinsic(ActiveLaneMask, Start);
4458bcb0991SDimitry Andric   }
4468bcb0991SDimitry Andric 
447e8d8bef9SDimitry Andric   // Remove dead instructions and now dead phis.
448e8d8bef9SDimitry Andric   for (auto *II : ActiveLaneMasks)
449e8d8bef9SDimitry Andric     RecursivelyDeleteTriviallyDeadInstructions(II);
450bdd1243dSDimitry Andric   for (auto *I : L->blocks())
451e8d8bef9SDimitry Andric     DeleteDeadPHIs(I);
4528bcb0991SDimitry Andric   return true;
4538bcb0991SDimitry Andric }
4548bcb0991SDimitry Andric 
4558bcb0991SDimitry Andric Pass *llvm::createMVETailPredicationPass() {
4568bcb0991SDimitry Andric   return new MVETailPredication();
4578bcb0991SDimitry Andric }
4588bcb0991SDimitry Andric 
4598bcb0991SDimitry Andric char MVETailPredication::ID = 0;
4608bcb0991SDimitry Andric 
4618bcb0991SDimitry Andric INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
4628bcb0991SDimitry Andric INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
463