xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/ARM/MVETailPredication.cpp (revision 480093f4440d54b30b3025afeac24b48f2ba7a2e)
18bcb0991SDimitry Andric //===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
28bcb0991SDimitry Andric //
38bcb0991SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48bcb0991SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
58bcb0991SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68bcb0991SDimitry Andric //
78bcb0991SDimitry Andric //===----------------------------------------------------------------------===//
88bcb0991SDimitry Andric //
98bcb0991SDimitry Andric /// \file
108bcb0991SDimitry Andric /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
118bcb0991SDimitry Andric /// branches to help accelerate DSP applications. These two extensions can be
128bcb0991SDimitry Andric /// combined to provide implicit vector predication within a low-overhead loop.
138bcb0991SDimitry Andric /// The HardwareLoops pass inserts intrinsics identifying loops that the
148bcb0991SDimitry Andric /// backend will attempt to convert into a low-overhead loop. The vectorizer is
158bcb0991SDimitry Andric /// responsible for generating a vectorized loop in which the lanes are
168bcb0991SDimitry Andric /// predicated upon the iteration counter. This pass looks at these predicated
178bcb0991SDimitry Andric /// vector loops, that are targets for low-overhead loops, and prepares it for
188bcb0991SDimitry Andric /// code generation. Once the vectorizer has produced a masked loop, there's a
198bcb0991SDimitry Andric /// couple of final forms:
208bcb0991SDimitry Andric /// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCTP instructions, predicating multiple VPT
228bcb0991SDimitry Andric ///   blocks of instructions operating on different vector types.
23*480093f4SDimitry Andric ///
/// This pass inserts the VCTP intrinsic to represent the effect of
25*480093f4SDimitry Andric /// tail predication. This will be picked up by the ARM Low-overhead loop pass,
26*480093f4SDimitry Andric /// which performs the final transformation to a DLSTP or WLSTP tail-predicated
27*480093f4SDimitry Andric /// loop.
288bcb0991SDimitry Andric 
29*480093f4SDimitry Andric #include "ARM.h"
30*480093f4SDimitry Andric #include "ARMSubtarget.h"
318bcb0991SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
328bcb0991SDimitry Andric #include "llvm/Analysis/LoopPass.h"
338bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
348bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpander.h"
358bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h"
368bcb0991SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
378bcb0991SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
388bcb0991SDimitry Andric #include "llvm/IR/IRBuilder.h"
39*480093f4SDimitry Andric #include "llvm/IR/Instructions.h"
40*480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h"
418bcb0991SDimitry Andric #include "llvm/IR/PatternMatch.h"
428bcb0991SDimitry Andric #include "llvm/Support/Debug.h"
438bcb0991SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
448bcb0991SDimitry Andric 
458bcb0991SDimitry Andric using namespace llvm;
468bcb0991SDimitry Andric 
478bcb0991SDimitry Andric #define DEBUG_TYPE "mve-tail-predication"
488bcb0991SDimitry Andric #define DESC "Transform predicated vector loops to use MVE tail predication"
498bcb0991SDimitry Andric 
50*480093f4SDimitry Andric cl::opt<bool>
518bcb0991SDimitry Andric DisableTailPredication("disable-mve-tail-predication", cl::Hidden,
528bcb0991SDimitry Andric                        cl::init(true),
538bcb0991SDimitry Andric                        cl::desc("Disable MVE Tail Predication"));
548bcb0991SDimitry Andric namespace {
558bcb0991SDimitry Andric 
568bcb0991SDimitry Andric class MVETailPredication : public LoopPass {
578bcb0991SDimitry Andric   SmallVector<IntrinsicInst*, 4> MaskedInsts;
588bcb0991SDimitry Andric   Loop *L = nullptr;
598bcb0991SDimitry Andric   ScalarEvolution *SE = nullptr;
608bcb0991SDimitry Andric   TargetTransformInfo *TTI = nullptr;
618bcb0991SDimitry Andric 
628bcb0991SDimitry Andric public:
638bcb0991SDimitry Andric   static char ID;
648bcb0991SDimitry Andric 
658bcb0991SDimitry Andric   MVETailPredication() : LoopPass(ID) { }
668bcb0991SDimitry Andric 
678bcb0991SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
688bcb0991SDimitry Andric     AU.addRequired<ScalarEvolutionWrapperPass>();
698bcb0991SDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
708bcb0991SDimitry Andric     AU.addRequired<TargetPassConfig>();
718bcb0991SDimitry Andric     AU.addRequired<TargetTransformInfoWrapperPass>();
728bcb0991SDimitry Andric     AU.addPreserved<LoopInfoWrapperPass>();
738bcb0991SDimitry Andric     AU.setPreservesCFG();
748bcb0991SDimitry Andric   }
758bcb0991SDimitry Andric 
768bcb0991SDimitry Andric   bool runOnLoop(Loop *L, LPPassManager&) override;
778bcb0991SDimitry Andric 
788bcb0991SDimitry Andric private:
798bcb0991SDimitry Andric 
808bcb0991SDimitry Andric   /// Perform the relevant checks on the loop and convert if possible.
818bcb0991SDimitry Andric   bool TryConvert(Value *TripCount);
828bcb0991SDimitry Andric 
838bcb0991SDimitry Andric   /// Return whether this is a vectorized loop, that contains masked
848bcb0991SDimitry Andric   /// load/stores.
858bcb0991SDimitry Andric   bool IsPredicatedVectorLoop();
868bcb0991SDimitry Andric 
878bcb0991SDimitry Andric   /// Compute a value for the total number of elements that the predicated
888bcb0991SDimitry Andric   /// loop will process.
898bcb0991SDimitry Andric   Value *ComputeElements(Value *TripCount, VectorType *VecTy);
908bcb0991SDimitry Andric 
918bcb0991SDimitry Andric   /// Is the icmp that generates an i1 vector, based upon a loop counter
928bcb0991SDimitry Andric   /// and a limit that is defined outside the loop.
938bcb0991SDimitry Andric   bool isTailPredicate(Instruction *Predicate, Value *NumElements);
94*480093f4SDimitry Andric 
95*480093f4SDimitry Andric   /// Insert the intrinsic to represent the effect of tail predication.
96*480093f4SDimitry Andric   void InsertVCTPIntrinsic(Instruction *Predicate,
97*480093f4SDimitry Andric                            DenseMap<Instruction*, Instruction*> &NewPredicates,
98*480093f4SDimitry Andric                            VectorType *VecTy,
99*480093f4SDimitry Andric                            Value *NumElements);
1008bcb0991SDimitry Andric };
1018bcb0991SDimitry Andric 
1028bcb0991SDimitry Andric } // end namespace
1038bcb0991SDimitry Andric 
1048bcb0991SDimitry Andric static bool IsDecrement(Instruction &I) {
1058bcb0991SDimitry Andric   auto *Call = dyn_cast<IntrinsicInst>(&I);
1068bcb0991SDimitry Andric   if (!Call)
1078bcb0991SDimitry Andric     return false;
1088bcb0991SDimitry Andric 
1098bcb0991SDimitry Andric   Intrinsic::ID ID = Call->getIntrinsicID();
1108bcb0991SDimitry Andric   return ID == Intrinsic::loop_decrement_reg;
1118bcb0991SDimitry Andric }
1128bcb0991SDimitry Andric 
1138bcb0991SDimitry Andric static bool IsMasked(Instruction *I) {
1148bcb0991SDimitry Andric   auto *Call = dyn_cast<IntrinsicInst>(I);
1158bcb0991SDimitry Andric   if (!Call)
1168bcb0991SDimitry Andric     return false;
1178bcb0991SDimitry Andric 
1188bcb0991SDimitry Andric   Intrinsic::ID ID = Call->getIntrinsicID();
1198bcb0991SDimitry Andric   // TODO: Support gather/scatter expand/compress operations.
1208bcb0991SDimitry Andric   return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load;
1218bcb0991SDimitry Andric }
1228bcb0991SDimitry Andric 
1238bcb0991SDimitry Andric bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
1248bcb0991SDimitry Andric   if (skipLoop(L) || DisableTailPredication)
1258bcb0991SDimitry Andric     return false;
1268bcb0991SDimitry Andric 
1278bcb0991SDimitry Andric   Function &F = *L->getHeader()->getParent();
1288bcb0991SDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
1298bcb0991SDimitry Andric   auto &TM = TPC.getTM<TargetMachine>();
1308bcb0991SDimitry Andric   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1318bcb0991SDimitry Andric   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1328bcb0991SDimitry Andric   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1338bcb0991SDimitry Andric   this->L = L;
1348bcb0991SDimitry Andric 
1358bcb0991SDimitry Andric   // The MVE and LOB extensions are combined to enable tail-predication, but
1368bcb0991SDimitry Andric   // there's nothing preventing us from generating VCTP instructions for v8.1m.
1378bcb0991SDimitry Andric   if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
138*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
1398bcb0991SDimitry Andric     return false;
1408bcb0991SDimitry Andric   }
1418bcb0991SDimitry Andric 
1428bcb0991SDimitry Andric   BasicBlock *Preheader = L->getLoopPreheader();
1438bcb0991SDimitry Andric   if (!Preheader)
1448bcb0991SDimitry Andric     return false;
1458bcb0991SDimitry Andric 
1468bcb0991SDimitry Andric   auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
1478bcb0991SDimitry Andric     for (auto &I : *BB) {
1488bcb0991SDimitry Andric       auto *Call = dyn_cast<IntrinsicInst>(&I);
1498bcb0991SDimitry Andric       if (!Call)
1508bcb0991SDimitry Andric         continue;
1518bcb0991SDimitry Andric 
1528bcb0991SDimitry Andric       Intrinsic::ID ID = Call->getIntrinsicID();
1538bcb0991SDimitry Andric       if (ID == Intrinsic::set_loop_iterations ||
1548bcb0991SDimitry Andric           ID == Intrinsic::test_set_loop_iterations)
1558bcb0991SDimitry Andric         return cast<IntrinsicInst>(&I);
1568bcb0991SDimitry Andric     }
1578bcb0991SDimitry Andric     return nullptr;
1588bcb0991SDimitry Andric   };
1598bcb0991SDimitry Andric 
1608bcb0991SDimitry Andric   // Look for the hardware loop intrinsic that sets the iteration count.
1618bcb0991SDimitry Andric   IntrinsicInst *Setup = FindLoopIterations(Preheader);
1628bcb0991SDimitry Andric 
1638bcb0991SDimitry Andric   // The test.set iteration could live in the pre-preheader.
1648bcb0991SDimitry Andric   if (!Setup) {
1658bcb0991SDimitry Andric     if (!Preheader->getSinglePredecessor())
1668bcb0991SDimitry Andric       return false;
1678bcb0991SDimitry Andric     Setup = FindLoopIterations(Preheader->getSinglePredecessor());
1688bcb0991SDimitry Andric     if (!Setup)
1698bcb0991SDimitry Andric       return false;
1708bcb0991SDimitry Andric   }
1718bcb0991SDimitry Andric 
1728bcb0991SDimitry Andric   // Search for the hardware loop intrinic that decrements the loop counter.
1738bcb0991SDimitry Andric   IntrinsicInst *Decrement = nullptr;
1748bcb0991SDimitry Andric   for (auto *BB : L->getBlocks()) {
1758bcb0991SDimitry Andric     for (auto &I : *BB) {
1768bcb0991SDimitry Andric       if (IsDecrement(I)) {
1778bcb0991SDimitry Andric         Decrement = cast<IntrinsicInst>(&I);
1788bcb0991SDimitry Andric         break;
1798bcb0991SDimitry Andric       }
1808bcb0991SDimitry Andric     }
1818bcb0991SDimitry Andric   }
1828bcb0991SDimitry Andric 
1838bcb0991SDimitry Andric   if (!Decrement)
1848bcb0991SDimitry Andric     return false;
1858bcb0991SDimitry Andric 
186*480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
1878bcb0991SDimitry Andric              << *Decrement << "\n");
188*480093f4SDimitry Andric   return TryConvert(Setup->getArgOperand(0));
1898bcb0991SDimitry Andric }
1908bcb0991SDimitry Andric 
1918bcb0991SDimitry Andric bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) {
1928bcb0991SDimitry Andric   // Look for the following:
1938bcb0991SDimitry Andric 
1948bcb0991SDimitry Andric   // %trip.count.minus.1 = add i32 %N, -1
1958bcb0991SDimitry Andric   // %broadcast.splatinsert10 = insertelement <4 x i32> undef,
1968bcb0991SDimitry Andric   //                                          i32 %trip.count.minus.1, i32 0
1978bcb0991SDimitry Andric   // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
1988bcb0991SDimitry Andric   //                                    <4 x i32> undef,
1998bcb0991SDimitry Andric   //                                    <4 x i32> zeroinitializer
2008bcb0991SDimitry Andric   // ...
2018bcb0991SDimitry Andric   // ...
2028bcb0991SDimitry Andric   // %index = phi i32
2038bcb0991SDimitry Andric   // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0
2048bcb0991SDimitry Andric   // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert,
2058bcb0991SDimitry Andric   //                                  <4 x i32> undef,
2068bcb0991SDimitry Andric   //                                  <4 x i32> zeroinitializer
2078bcb0991SDimitry Andric   // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3>
2088bcb0991SDimitry Andric   // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11
2098bcb0991SDimitry Andric 
2108bcb0991SDimitry Andric   // And return whether V == %pred.
2118bcb0991SDimitry Andric 
2128bcb0991SDimitry Andric   using namespace PatternMatch;
2138bcb0991SDimitry Andric 
2148bcb0991SDimitry Andric   CmpInst::Predicate Pred;
2158bcb0991SDimitry Andric   Instruction *Shuffle = nullptr;
2168bcb0991SDimitry Andric   Instruction *Induction = nullptr;
2178bcb0991SDimitry Andric 
2188bcb0991SDimitry Andric   // The vector icmp
2198bcb0991SDimitry Andric   if (!match(I, m_ICmp(Pred, m_Instruction(Induction),
2208bcb0991SDimitry Andric                        m_Instruction(Shuffle))) ||
221*480093f4SDimitry Andric       Pred != ICmpInst::ICMP_ULE)
2228bcb0991SDimitry Andric     return false;
2238bcb0991SDimitry Andric 
2248bcb0991SDimitry Andric   // First find the stuff outside the loop which is setting up the limit
2258bcb0991SDimitry Andric   // vector....
2268bcb0991SDimitry Andric   // The invariant shuffle that broadcast the limit into a vector.
2278bcb0991SDimitry Andric   Instruction *Insert = nullptr;
2288bcb0991SDimitry Andric   if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
2298bcb0991SDimitry Andric                                       m_Zero())))
2308bcb0991SDimitry Andric     return false;
2318bcb0991SDimitry Andric 
2328bcb0991SDimitry Andric   // Insert the limit into a vector.
2338bcb0991SDimitry Andric   Instruction *BECount = nullptr;
2348bcb0991SDimitry Andric   if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount),
2358bcb0991SDimitry Andric                                      m_Zero())))
2368bcb0991SDimitry Andric     return false;
2378bcb0991SDimitry Andric 
2388bcb0991SDimitry Andric   // The limit calculation, backedge count.
2398bcb0991SDimitry Andric   Value *TripCount = nullptr;
2408bcb0991SDimitry Andric   if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes())))
2418bcb0991SDimitry Andric     return false;
2428bcb0991SDimitry Andric 
243*480093f4SDimitry Andric   if (TripCount != NumElements || !L->isLoopInvariant(BECount))
2448bcb0991SDimitry Andric     return false;
2458bcb0991SDimitry Andric 
2468bcb0991SDimitry Andric   // Now back to searching inside the loop body...
2478bcb0991SDimitry Andric   // Find the add with takes the index iv and adds a constant vector to it.
2488bcb0991SDimitry Andric   Instruction *BroadcastSplat = nullptr;
2498bcb0991SDimitry Andric   Constant *Const = nullptr;
2508bcb0991SDimitry Andric   if (!match(Induction, m_Add(m_Instruction(BroadcastSplat),
2518bcb0991SDimitry Andric                               m_Constant(Const))))
2528bcb0991SDimitry Andric    return false;
2538bcb0991SDimitry Andric 
2548bcb0991SDimitry Andric   // Check that we're adding <0, 1, 2, 3...
2558bcb0991SDimitry Andric   if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) {
2568bcb0991SDimitry Andric     for (unsigned i = 0; i < CDS->getNumElements(); ++i) {
2578bcb0991SDimitry Andric       if (CDS->getElementAsInteger(i) != i)
2588bcb0991SDimitry Andric         return false;
2598bcb0991SDimitry Andric     }
2608bcb0991SDimitry Andric   } else
2618bcb0991SDimitry Andric     return false;
2628bcb0991SDimitry Andric 
2638bcb0991SDimitry Andric   // The shuffle which broadcasts the index iv into a vector.
2648bcb0991SDimitry Andric   if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
2658bcb0991SDimitry Andric                                              m_Zero())))
2668bcb0991SDimitry Andric     return false;
2678bcb0991SDimitry Andric 
2688bcb0991SDimitry Andric   // The insert element which initialises a vector with the index iv.
2698bcb0991SDimitry Andric   Instruction *IV = nullptr;
2708bcb0991SDimitry Andric   if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero())))
2718bcb0991SDimitry Andric     return false;
2728bcb0991SDimitry Andric 
2738bcb0991SDimitry Andric   // The index iv.
2748bcb0991SDimitry Andric   auto *Phi = dyn_cast<PHINode>(IV);
2758bcb0991SDimitry Andric   if (!Phi)
2768bcb0991SDimitry Andric     return false;
2778bcb0991SDimitry Andric 
2788bcb0991SDimitry Andric   // TODO: Don't think we need to check the entry value.
2798bcb0991SDimitry Andric   Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader());
2808bcb0991SDimitry Andric   if (!match(OnEntry, m_Zero()))
2818bcb0991SDimitry Andric     return false;
2828bcb0991SDimitry Andric 
2838bcb0991SDimitry Andric   Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch());
2848bcb0991SDimitry Andric   unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements();
2858bcb0991SDimitry Andric 
2868bcb0991SDimitry Andric   Instruction *LHS = nullptr;
2878bcb0991SDimitry Andric   if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes))))
2888bcb0991SDimitry Andric     return false;
2898bcb0991SDimitry Andric 
2908bcb0991SDimitry Andric   return LHS == Phi;
2918bcb0991SDimitry Andric }
2928bcb0991SDimitry Andric 
2938bcb0991SDimitry Andric static VectorType* getVectorType(IntrinsicInst *I) {
2948bcb0991SDimitry Andric   unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1;
2958bcb0991SDimitry Andric   auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType());
2968bcb0991SDimitry Andric   return cast<VectorType>(PtrTy->getElementType());
2978bcb0991SDimitry Andric }
2988bcb0991SDimitry Andric 
2998bcb0991SDimitry Andric bool MVETailPredication::IsPredicatedVectorLoop() {
3008bcb0991SDimitry Andric   // Check that the loop contains at least one masked load/store intrinsic.
3018bcb0991SDimitry Andric   // We only support 'normal' vector instructions - other than masked
3028bcb0991SDimitry Andric   // load/stores.
3038bcb0991SDimitry Andric   for (auto *BB : L->getBlocks()) {
3048bcb0991SDimitry Andric     for (auto &I : *BB) {
3058bcb0991SDimitry Andric       if (IsMasked(&I)) {
3068bcb0991SDimitry Andric         VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I));
3078bcb0991SDimitry Andric         unsigned Lanes = VecTy->getNumElements();
3088bcb0991SDimitry Andric         unsigned ElementWidth = VecTy->getScalarSizeInBits();
3098bcb0991SDimitry Andric         // MVE vectors are 128-bit, but don't support 128 x i1.
3108bcb0991SDimitry Andric         // TODO: Can we support vectors larger than 128-bits?
3118bcb0991SDimitry Andric         unsigned MaxWidth = TTI->getRegisterBitWidth(true);
312*480093f4SDimitry Andric         if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth)
3138bcb0991SDimitry Andric           return false;
3148bcb0991SDimitry Andric         MaskedInsts.push_back(cast<IntrinsicInst>(&I));
3158bcb0991SDimitry Andric       } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) {
3168bcb0991SDimitry Andric         for (auto &U : Int->args()) {
3178bcb0991SDimitry Andric           if (isa<VectorType>(U->getType()))
3188bcb0991SDimitry Andric             return false;
3198bcb0991SDimitry Andric         }
3208bcb0991SDimitry Andric       }
3218bcb0991SDimitry Andric     }
3228bcb0991SDimitry Andric   }
3238bcb0991SDimitry Andric 
3248bcb0991SDimitry Andric   return !MaskedInsts.empty();
3258bcb0991SDimitry Andric }
3268bcb0991SDimitry Andric 
3278bcb0991SDimitry Andric Value* MVETailPredication::ComputeElements(Value *TripCount,
3288bcb0991SDimitry Andric                                            VectorType *VecTy) {
3298bcb0991SDimitry Andric   const SCEV *TripCountSE = SE->getSCEV(TripCount);
3308bcb0991SDimitry Andric   ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()),
3318bcb0991SDimitry Andric                                      VecTy->getNumElements());
3328bcb0991SDimitry Andric 
3338bcb0991SDimitry Andric   if (VF->equalsInt(1))
3348bcb0991SDimitry Andric     return nullptr;
3358bcb0991SDimitry Andric 
3368bcb0991SDimitry Andric   // TODO: Support constant trip counts.
3378bcb0991SDimitry Andric   auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* {
3388bcb0991SDimitry Andric     if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
3398bcb0991SDimitry Andric       if (Const->getAPInt() != -VF->getValue())
3408bcb0991SDimitry Andric         return nullptr;
3418bcb0991SDimitry Andric     } else
3428bcb0991SDimitry Andric       return nullptr;
3438bcb0991SDimitry Andric     return dyn_cast<SCEVMulExpr>(S->getOperand(1));
3448bcb0991SDimitry Andric   };
3458bcb0991SDimitry Andric 
3468bcb0991SDimitry Andric   auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* {
3478bcb0991SDimitry Andric     if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
3488bcb0991SDimitry Andric       if (Const->getValue() != VF)
3498bcb0991SDimitry Andric         return nullptr;
3508bcb0991SDimitry Andric     } else
3518bcb0991SDimitry Andric       return nullptr;
3528bcb0991SDimitry Andric     return dyn_cast<SCEVUDivExpr>(S->getOperand(1));
3538bcb0991SDimitry Andric   };
3548bcb0991SDimitry Andric 
3558bcb0991SDimitry Andric   auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* {
3568bcb0991SDimitry Andric     if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) {
3578bcb0991SDimitry Andric       if (Const->getValue() != VF)
3588bcb0991SDimitry Andric         return nullptr;
3598bcb0991SDimitry Andric     } else
3608bcb0991SDimitry Andric       return nullptr;
3618bcb0991SDimitry Andric 
3628bcb0991SDimitry Andric     if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) {
3638bcb0991SDimitry Andric       if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) {
3648bcb0991SDimitry Andric         if (Const->getAPInt() != (VF->getValue() - 1))
3658bcb0991SDimitry Andric           return nullptr;
3668bcb0991SDimitry Andric       } else
3678bcb0991SDimitry Andric         return nullptr;
3688bcb0991SDimitry Andric 
3698bcb0991SDimitry Andric       return RoundUp->getOperand(1);
3708bcb0991SDimitry Andric     }
3718bcb0991SDimitry Andric     return nullptr;
3728bcb0991SDimitry Andric   };
3738bcb0991SDimitry Andric 
3748bcb0991SDimitry Andric   // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to
3758bcb0991SDimitry Andric   // determine the numbers of elements instead? Looks like this is what is used
3768bcb0991SDimitry Andric   // for delinearization, but I'm not sure if it can be applied to the
3778bcb0991SDimitry Andric   // vectorized form - at least not without a bit more work than I feel
3788bcb0991SDimitry Andric   // comfortable with.
3798bcb0991SDimitry Andric 
3808bcb0991SDimitry Andric   // Search for Elems in the following SCEV:
3818bcb0991SDimitry Andric   // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw>
3828bcb0991SDimitry Andric   const SCEV *Elems = nullptr;
3838bcb0991SDimitry Andric   if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE))
3848bcb0991SDimitry Andric     if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1)))
3858bcb0991SDimitry Andric       if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS()))
3868bcb0991SDimitry Andric         if (auto *Mul = VisitAdd(Add))
3878bcb0991SDimitry Andric           if (auto *Div = VisitMul(Mul))
3888bcb0991SDimitry Andric             if (auto *Res = VisitDiv(Div))
3898bcb0991SDimitry Andric               Elems = Res;
3908bcb0991SDimitry Andric 
3918bcb0991SDimitry Andric   if (!Elems)
3928bcb0991SDimitry Andric     return nullptr;
3938bcb0991SDimitry Andric 
3948bcb0991SDimitry Andric   Instruction *InsertPt = L->getLoopPreheader()->getTerminator();
3958bcb0991SDimitry Andric   if (!isSafeToExpandAt(Elems, InsertPt, *SE))
3968bcb0991SDimitry Andric     return nullptr;
3978bcb0991SDimitry Andric 
3988bcb0991SDimitry Andric   auto DL = L->getHeader()->getModule()->getDataLayout();
3998bcb0991SDimitry Andric   SCEVExpander Expander(*SE, DL, "elements");
4008bcb0991SDimitry Andric   return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt);
4018bcb0991SDimitry Andric }
4028bcb0991SDimitry Andric 
4038bcb0991SDimitry Andric // Look through the exit block to see whether there's a duplicate predicate
4048bcb0991SDimitry Andric // instruction. This can happen when we need to perform a select on values
4058bcb0991SDimitry Andric // from the last and previous iteration. Instead of doing a straight
4068bcb0991SDimitry Andric // replacement of that predicate with the vctp, clone the vctp and place it
4078bcb0991SDimitry Andric // in the block. This means that the VPR doesn't have to be live into the
4088bcb0991SDimitry Andric // exit block which should make it easier to convert this loop into a proper
4098bcb0991SDimitry Andric // tail predicated loop.
4108bcb0991SDimitry Andric static void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates,
4118bcb0991SDimitry Andric                     SetVector<Instruction*> &MaybeDead, Loop *L) {
412*480093f4SDimitry Andric   BasicBlock *Exit = L->getUniqueExitBlock();
413*480093f4SDimitry Andric   if (!Exit) {
414*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n");
415*480093f4SDimitry Andric     return;
416*480093f4SDimitry Andric   }
417*480093f4SDimitry Andric 
4188bcb0991SDimitry Andric   for (auto &Pair : NewPredicates) {
4198bcb0991SDimitry Andric     Instruction *OldPred = Pair.first;
4208bcb0991SDimitry Andric     Instruction *NewPred = Pair.second;
4218bcb0991SDimitry Andric 
4228bcb0991SDimitry Andric     for (auto &I : *Exit) {
4238bcb0991SDimitry Andric       if (I.isSameOperationAs(OldPred)) {
4248bcb0991SDimitry Andric         Instruction *PredClone = NewPred->clone();
4258bcb0991SDimitry Andric         PredClone->insertBefore(&I);
4268bcb0991SDimitry Andric         I.replaceAllUsesWith(PredClone);
4278bcb0991SDimitry Andric         MaybeDead.insert(&I);
428*480093f4SDimitry Andric         LLVM_DEBUG(dbgs() << "ARM TP: replacing: "; I.dump();
429*480093f4SDimitry Andric                    dbgs() << "ARM TP: with:      "; PredClone->dump());
4308bcb0991SDimitry Andric         break;
4318bcb0991SDimitry Andric       }
4328bcb0991SDimitry Andric     }
4338bcb0991SDimitry Andric   }
4348bcb0991SDimitry Andric 
4358bcb0991SDimitry Andric   // Drop references and add operands to check for dead.
4368bcb0991SDimitry Andric   SmallPtrSet<Instruction*, 4> Dead;
4378bcb0991SDimitry Andric   while (!MaybeDead.empty()) {
4388bcb0991SDimitry Andric     auto *I = MaybeDead.front();
4398bcb0991SDimitry Andric     MaybeDead.remove(I);
4408bcb0991SDimitry Andric     if (I->hasNUsesOrMore(1))
4418bcb0991SDimitry Andric       continue;
4428bcb0991SDimitry Andric 
4438bcb0991SDimitry Andric     for (auto &U : I->operands()) {
4448bcb0991SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(U))
4458bcb0991SDimitry Andric         MaybeDead.insert(OpI);
4468bcb0991SDimitry Andric     }
4478bcb0991SDimitry Andric     I->dropAllReferences();
4488bcb0991SDimitry Andric     Dead.insert(I);
4498bcb0991SDimitry Andric   }
4508bcb0991SDimitry Andric 
451*480093f4SDimitry Andric   for (auto *I : Dead) {
452*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump());
4538bcb0991SDimitry Andric     I->eraseFromParent();
454*480093f4SDimitry Andric   }
4558bcb0991SDimitry Andric 
4568bcb0991SDimitry Andric   for (auto I : L->blocks())
4578bcb0991SDimitry Andric     DeleteDeadPHIs(I);
4588bcb0991SDimitry Andric }
4598bcb0991SDimitry Andric 
460*480093f4SDimitry Andric void MVETailPredication::InsertVCTPIntrinsic(Instruction *Predicate,
461*480093f4SDimitry Andric     DenseMap<Instruction*, Instruction*> &NewPredicates,
462*480093f4SDimitry Andric     VectorType *VecTy, Value *NumElements) {
463*480093f4SDimitry Andric   IRBuilder<> Builder(L->getHeader()->getFirstNonPHI());
464*480093f4SDimitry Andric   Module *M = L->getHeader()->getModule();
465*480093f4SDimitry Andric   Type *Ty = IntegerType::get(M->getContext(), 32);
4668bcb0991SDimitry Andric 
467*480093f4SDimitry Andric   // Insert a phi to count the number of elements processed by the loop.
468*480093f4SDimitry Andric   PHINode *Processed = Builder.CreatePHI(Ty, 2);
469*480093f4SDimitry Andric   Processed->addIncoming(NumElements, L->getLoopPreheader());
470*480093f4SDimitry Andric 
471*480093f4SDimitry Andric   // Insert the intrinsic to represent the effect of tail predication.
472*480093f4SDimitry Andric   Builder.SetInsertPoint(cast<Instruction>(Predicate));
473*480093f4SDimitry Andric   ConstantInt *Factor =
474*480093f4SDimitry Andric     ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements());
475*480093f4SDimitry Andric 
476*480093f4SDimitry Andric   Intrinsic::ID VCTPID;
477*480093f4SDimitry Andric   switch (VecTy->getNumElements()) {
478*480093f4SDimitry Andric   default:
479*480093f4SDimitry Andric     llvm_unreachable("unexpected number of lanes");
480*480093f4SDimitry Andric   case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
481*480093f4SDimitry Andric   case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
482*480093f4SDimitry Andric   case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
483*480093f4SDimitry Andric 
484*480093f4SDimitry Andric     // FIXME: vctp64 currently not supported because the predicate
485*480093f4SDimitry Andric     // vector wants to be <2 x i1>, but v2i1 is not a legal MVE
486*480093f4SDimitry Andric     // type, so problems happen at isel time.
487*480093f4SDimitry Andric     // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics
488*480093f4SDimitry Andric     // purposes, but takes a v4i1 instead of a v2i1.
489*480093f4SDimitry Andric   }
490*480093f4SDimitry Andric   Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
491*480093f4SDimitry Andric   Value *TailPredicate = Builder.CreateCall(VCTP, Processed);
492*480093f4SDimitry Andric   Predicate->replaceAllUsesWith(TailPredicate);
493*480093f4SDimitry Andric   NewPredicates[Predicate] = cast<Instruction>(TailPredicate);
494*480093f4SDimitry Andric 
495*480093f4SDimitry Andric   // Add the incoming value to the new phi.
496*480093f4SDimitry Andric   // TODO: This add likely already exists in the loop.
497*480093f4SDimitry Andric   Value *Remaining = Builder.CreateSub(Processed, Factor);
498*480093f4SDimitry Andric   Processed->addIncoming(Remaining, L->getLoopLatch());
499*480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
500*480093f4SDimitry Andric              << *Processed << "\n"
501*480093f4SDimitry Andric              << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n");
502*480093f4SDimitry Andric }
503*480093f4SDimitry Andric 
504*480093f4SDimitry Andric bool MVETailPredication::TryConvert(Value *TripCount) {
505*480093f4SDimitry Andric   if (!IsPredicatedVectorLoop()) {
506*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop");
507*480093f4SDimitry Andric     return false;
508*480093f4SDimitry Andric   }
509*480093f4SDimitry Andric 
510*480093f4SDimitry Andric   LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
5118bcb0991SDimitry Andric 
5128bcb0991SDimitry Andric   // Walk through the masked intrinsics and try to find whether the predicate
5138bcb0991SDimitry Andric   // operand is generated from an induction variable.
5148bcb0991SDimitry Andric   SetVector<Instruction*> Predicates;
5158bcb0991SDimitry Andric   DenseMap<Instruction*, Instruction*> NewPredicates;
5168bcb0991SDimitry Andric 
5178bcb0991SDimitry Andric   for (auto *I : MaskedInsts) {
5188bcb0991SDimitry Andric     Intrinsic::ID ID = I->getIntrinsicID();
5198bcb0991SDimitry Andric     unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3;
5208bcb0991SDimitry Andric     auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp));
5218bcb0991SDimitry Andric     if (!Predicate || Predicates.count(Predicate))
5228bcb0991SDimitry Andric       continue;
5238bcb0991SDimitry Andric 
5248bcb0991SDimitry Andric     VectorType *VecTy = getVectorType(I);
5258bcb0991SDimitry Andric     Value *NumElements = ComputeElements(TripCount, VecTy);
5268bcb0991SDimitry Andric     if (!NumElements)
5278bcb0991SDimitry Andric       continue;
5288bcb0991SDimitry Andric 
5298bcb0991SDimitry Andric     if (!isTailPredicate(Predicate, NumElements)) {
530*480093f4SDimitry Andric       LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate << "\n");
5318bcb0991SDimitry Andric       continue;
5328bcb0991SDimitry Andric     }
5338bcb0991SDimitry Andric 
534*480093f4SDimitry Andric     LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate << "\n");
5358bcb0991SDimitry Andric     Predicates.insert(Predicate);
5368bcb0991SDimitry Andric 
537*480093f4SDimitry Andric     InsertVCTPIntrinsic(Predicate, NewPredicates, VecTy, NumElements);
5388bcb0991SDimitry Andric   }
5398bcb0991SDimitry Andric 
5408bcb0991SDimitry Andric   // Now clean up.
5418bcb0991SDimitry Andric   Cleanup(NewPredicates, Predicates, L);
5428bcb0991SDimitry Andric   return true;
5438bcb0991SDimitry Andric }
5448bcb0991SDimitry Andric 
5458bcb0991SDimitry Andric Pass *llvm::createMVETailPredicationPass() {
5468bcb0991SDimitry Andric   return new MVETailPredication();
5478bcb0991SDimitry Andric }
5488bcb0991SDimitry Andric 
5498bcb0991SDimitry Andric char MVETailPredication::ID = 0;
5508bcb0991SDimitry Andric 
5518bcb0991SDimitry Andric INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
5528bcb0991SDimitry Andric INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
553