18bcb0991SDimitry Andric //===- MVETailPredication.cpp - MVE Tail Predication ----------------------===// 28bcb0991SDimitry Andric // 38bcb0991SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48bcb0991SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 58bcb0991SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68bcb0991SDimitry Andric // 78bcb0991SDimitry Andric //===----------------------------------------------------------------------===// 88bcb0991SDimitry Andric // 98bcb0991SDimitry Andric /// \file 108bcb0991SDimitry Andric /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead 118bcb0991SDimitry Andric /// branches to help accelerate DSP applications. These two extensions can be 128bcb0991SDimitry Andric /// combined to provide implicit vector predication within a low-overhead loop. 138bcb0991SDimitry Andric /// The HardwareLoops pass inserts intrinsics identifying loops that the 148bcb0991SDimitry Andric /// backend will attempt to convert into a low-overhead loop. The vectorizer is 158bcb0991SDimitry Andric /// responsible for generating a vectorized loop in which the lanes are 168bcb0991SDimitry Andric /// predicated upon the iteration counter. This pass looks at these predicated 178bcb0991SDimitry Andric /// vector loops, that are targets for low-overhead loops, and prepares it for 188bcb0991SDimitry Andric /// code generation. Once the vectorizer has produced a masked loop, there's a 198bcb0991SDimitry Andric /// couple of final forms: 208bcb0991SDimitry Andric /// - A tail-predicated loop, with implicit predication. 218bcb0991SDimitry Andric /// - A loop containing multiple VCPT instructions, predicating multiple VPT 228bcb0991SDimitry Andric /// blocks of instructions operating on different vector types. 23*480093f4SDimitry Andric /// 24*480093f4SDimitry Andric /// This pass inserts the inserts the VCTP intrinsic to represent the effect of 25*480093f4SDimitry Andric /// tail predication. This will be picked up by the ARM Low-overhead loop pass, 26*480093f4SDimitry Andric /// which performs the final transformation to a DLSTP or WLSTP tail-predicated 27*480093f4SDimitry Andric /// loop. 288bcb0991SDimitry Andric 29*480093f4SDimitry Andric #include "ARM.h" 30*480093f4SDimitry Andric #include "ARMSubtarget.h" 318bcb0991SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 328bcb0991SDimitry Andric #include "llvm/Analysis/LoopPass.h" 338bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h" 348bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpander.h" 358bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h" 368bcb0991SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 378bcb0991SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 388bcb0991SDimitry Andric #include "llvm/IR/IRBuilder.h" 39*480093f4SDimitry Andric #include "llvm/IR/Instructions.h" 40*480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h" 418bcb0991SDimitry Andric #include "llvm/IR/PatternMatch.h" 428bcb0991SDimitry Andric #include "llvm/Support/Debug.h" 438bcb0991SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 448bcb0991SDimitry Andric 458bcb0991SDimitry Andric using namespace llvm; 468bcb0991SDimitry Andric 478bcb0991SDimitry Andric #define DEBUG_TYPE "mve-tail-predication" 488bcb0991SDimitry Andric #define DESC "Transform predicated vector loops to use MVE tail predication" 498bcb0991SDimitry Andric 50*480093f4SDimitry Andric cl::opt<bool> 518bcb0991SDimitry Andric DisableTailPredication("disable-mve-tail-predication", cl::Hidden, 528bcb0991SDimitry Andric cl::init(true), 538bcb0991SDimitry Andric cl::desc("Disable MVE Tail Predication")); 548bcb0991SDimitry Andric namespace { 558bcb0991SDimitry Andric 568bcb0991SDimitry Andric class MVETailPredication : public LoopPass { 578bcb0991SDimitry Andric SmallVector<IntrinsicInst*, 4> MaskedInsts; 588bcb0991SDimitry Andric Loop *L = nullptr; 598bcb0991SDimitry Andric ScalarEvolution *SE = nullptr; 608bcb0991SDimitry Andric TargetTransformInfo *TTI = nullptr; 618bcb0991SDimitry Andric 628bcb0991SDimitry Andric public: 638bcb0991SDimitry Andric static char ID; 648bcb0991SDimitry Andric 658bcb0991SDimitry Andric MVETailPredication() : LoopPass(ID) { } 668bcb0991SDimitry Andric 678bcb0991SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 688bcb0991SDimitry Andric AU.addRequired<ScalarEvolutionWrapperPass>(); 698bcb0991SDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 708bcb0991SDimitry Andric AU.addRequired<TargetPassConfig>(); 718bcb0991SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 728bcb0991SDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 738bcb0991SDimitry Andric AU.setPreservesCFG(); 748bcb0991SDimitry Andric } 758bcb0991SDimitry Andric 768bcb0991SDimitry Andric bool runOnLoop(Loop *L, LPPassManager&) override; 778bcb0991SDimitry Andric 788bcb0991SDimitry Andric private: 798bcb0991SDimitry Andric 808bcb0991SDimitry Andric /// Perform the relevant checks on the loop and convert if possible. 818bcb0991SDimitry Andric bool TryConvert(Value *TripCount); 828bcb0991SDimitry Andric 838bcb0991SDimitry Andric /// Return whether this is a vectorized loop, that contains masked 848bcb0991SDimitry Andric /// load/stores. 858bcb0991SDimitry Andric bool IsPredicatedVectorLoop(); 868bcb0991SDimitry Andric 878bcb0991SDimitry Andric /// Compute a value for the total number of elements that the predicated 888bcb0991SDimitry Andric /// loop will process. 898bcb0991SDimitry Andric Value *ComputeElements(Value *TripCount, VectorType *VecTy); 908bcb0991SDimitry Andric 918bcb0991SDimitry Andric /// Is the icmp that generates an i1 vector, based upon a loop counter 928bcb0991SDimitry Andric /// and a limit that is defined outside the loop. 938bcb0991SDimitry Andric bool isTailPredicate(Instruction *Predicate, Value *NumElements); 94*480093f4SDimitry Andric 95*480093f4SDimitry Andric /// Insert the intrinsic to represent the effect of tail predication. 96*480093f4SDimitry Andric void InsertVCTPIntrinsic(Instruction *Predicate, 97*480093f4SDimitry Andric DenseMap<Instruction*, Instruction*> &NewPredicates, 98*480093f4SDimitry Andric VectorType *VecTy, 99*480093f4SDimitry Andric Value *NumElements); 1008bcb0991SDimitry Andric }; 1018bcb0991SDimitry Andric 1028bcb0991SDimitry Andric } // end namespace 1038bcb0991SDimitry Andric 1048bcb0991SDimitry Andric static bool IsDecrement(Instruction &I) { 1058bcb0991SDimitry Andric auto *Call = dyn_cast<IntrinsicInst>(&I); 1068bcb0991SDimitry Andric if (!Call) 1078bcb0991SDimitry Andric return false; 1088bcb0991SDimitry Andric 1098bcb0991SDimitry Andric Intrinsic::ID ID = Call->getIntrinsicID(); 1108bcb0991SDimitry Andric return ID == Intrinsic::loop_decrement_reg; 1118bcb0991SDimitry Andric } 1128bcb0991SDimitry Andric 1138bcb0991SDimitry Andric static bool IsMasked(Instruction *I) { 1148bcb0991SDimitry Andric auto *Call = dyn_cast<IntrinsicInst>(I); 1158bcb0991SDimitry Andric if (!Call) 1168bcb0991SDimitry Andric return false; 1178bcb0991SDimitry Andric 1188bcb0991SDimitry Andric Intrinsic::ID ID = Call->getIntrinsicID(); 1198bcb0991SDimitry Andric // TODO: Support gather/scatter expand/compress operations. 1208bcb0991SDimitry Andric return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load; 1218bcb0991SDimitry Andric } 1228bcb0991SDimitry Andric 1238bcb0991SDimitry Andric bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { 1248bcb0991SDimitry Andric if (skipLoop(L) || DisableTailPredication) 1258bcb0991SDimitry Andric return false; 1268bcb0991SDimitry Andric 1278bcb0991SDimitry Andric Function &F = *L->getHeader()->getParent(); 1288bcb0991SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 1298bcb0991SDimitry Andric auto &TM = TPC.getTM<TargetMachine>(); 1308bcb0991SDimitry Andric auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 1318bcb0991SDimitry Andric TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 1328bcb0991SDimitry Andric SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 1338bcb0991SDimitry Andric this->L = L; 1348bcb0991SDimitry Andric 1358bcb0991SDimitry Andric // The MVE and LOB extensions are combined to enable tail-predication, but 1368bcb0991SDimitry Andric // there's nothing preventing us from generating VCTP instructions for v8.1m. 1378bcb0991SDimitry Andric if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { 138*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n"); 1398bcb0991SDimitry Andric return false; 1408bcb0991SDimitry Andric } 1418bcb0991SDimitry Andric 1428bcb0991SDimitry Andric BasicBlock *Preheader = L->getLoopPreheader(); 1438bcb0991SDimitry Andric if (!Preheader) 1448bcb0991SDimitry Andric return false; 1458bcb0991SDimitry Andric 1468bcb0991SDimitry Andric auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { 1478bcb0991SDimitry Andric for (auto &I : *BB) { 1488bcb0991SDimitry Andric auto *Call = dyn_cast<IntrinsicInst>(&I); 1498bcb0991SDimitry Andric if (!Call) 1508bcb0991SDimitry Andric continue; 1518bcb0991SDimitry Andric 1528bcb0991SDimitry Andric Intrinsic::ID ID = Call->getIntrinsicID(); 1538bcb0991SDimitry Andric if (ID == Intrinsic::set_loop_iterations || 1548bcb0991SDimitry Andric ID == Intrinsic::test_set_loop_iterations) 1558bcb0991SDimitry Andric return cast<IntrinsicInst>(&I); 1568bcb0991SDimitry Andric } 1578bcb0991SDimitry Andric return nullptr; 1588bcb0991SDimitry Andric }; 1598bcb0991SDimitry Andric 1608bcb0991SDimitry Andric // Look for the hardware loop intrinsic that sets the iteration count. 1618bcb0991SDimitry Andric IntrinsicInst *Setup = FindLoopIterations(Preheader); 1628bcb0991SDimitry Andric 1638bcb0991SDimitry Andric // The test.set iteration could live in the pre-preheader. 1648bcb0991SDimitry Andric if (!Setup) { 1658bcb0991SDimitry Andric if (!Preheader->getSinglePredecessor()) 1668bcb0991SDimitry Andric return false; 1678bcb0991SDimitry Andric Setup = FindLoopIterations(Preheader->getSinglePredecessor()); 1688bcb0991SDimitry Andric if (!Setup) 1698bcb0991SDimitry Andric return false; 1708bcb0991SDimitry Andric } 1718bcb0991SDimitry Andric 1728bcb0991SDimitry Andric // Search for the hardware loop intrinic that decrements the loop counter. 1738bcb0991SDimitry Andric IntrinsicInst *Decrement = nullptr; 1748bcb0991SDimitry Andric for (auto *BB : L->getBlocks()) { 1758bcb0991SDimitry Andric for (auto &I : *BB) { 1768bcb0991SDimitry Andric if (IsDecrement(I)) { 1778bcb0991SDimitry Andric Decrement = cast<IntrinsicInst>(&I); 1788bcb0991SDimitry Andric break; 1798bcb0991SDimitry Andric } 1808bcb0991SDimitry Andric } 1818bcb0991SDimitry Andric } 1828bcb0991SDimitry Andric 1838bcb0991SDimitry Andric if (!Decrement) 1848bcb0991SDimitry Andric return false; 1858bcb0991SDimitry Andric 186*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n" 1878bcb0991SDimitry Andric << *Decrement << "\n"); 188*480093f4SDimitry Andric return TryConvert(Setup->getArgOperand(0)); 1898bcb0991SDimitry Andric } 1908bcb0991SDimitry Andric 1918bcb0991SDimitry Andric bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) { 1928bcb0991SDimitry Andric // Look for the following: 1938bcb0991SDimitry Andric 1948bcb0991SDimitry Andric // %trip.count.minus.1 = add i32 %N, -1 1958bcb0991SDimitry Andric // %broadcast.splatinsert10 = insertelement <4 x i32> undef, 1968bcb0991SDimitry Andric // i32 %trip.count.minus.1, i32 0 1978bcb0991SDimitry Andric // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10, 1988bcb0991SDimitry Andric // <4 x i32> undef, 1998bcb0991SDimitry Andric // <4 x i32> zeroinitializer 2008bcb0991SDimitry Andric // ... 2018bcb0991SDimitry Andric // ... 2028bcb0991SDimitry Andric // %index = phi i32 2038bcb0991SDimitry Andric // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 2048bcb0991SDimitry Andric // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, 2058bcb0991SDimitry Andric // <4 x i32> undef, 2068bcb0991SDimitry Andric // <4 x i32> zeroinitializer 2078bcb0991SDimitry Andric // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3> 2088bcb0991SDimitry Andric // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11 2098bcb0991SDimitry Andric 2108bcb0991SDimitry Andric // And return whether V == %pred. 2118bcb0991SDimitry Andric 2128bcb0991SDimitry Andric using namespace PatternMatch; 2138bcb0991SDimitry Andric 2148bcb0991SDimitry Andric CmpInst::Predicate Pred; 2158bcb0991SDimitry Andric Instruction *Shuffle = nullptr; 2168bcb0991SDimitry Andric Instruction *Induction = nullptr; 2178bcb0991SDimitry Andric 2188bcb0991SDimitry Andric // The vector icmp 2198bcb0991SDimitry Andric if (!match(I, m_ICmp(Pred, m_Instruction(Induction), 2208bcb0991SDimitry Andric m_Instruction(Shuffle))) || 221*480093f4SDimitry Andric Pred != ICmpInst::ICMP_ULE) 2228bcb0991SDimitry Andric return false; 2238bcb0991SDimitry Andric 2248bcb0991SDimitry Andric // First find the stuff outside the loop which is setting up the limit 2258bcb0991SDimitry Andric // vector.... 2268bcb0991SDimitry Andric // The invariant shuffle that broadcast the limit into a vector. 2278bcb0991SDimitry Andric Instruction *Insert = nullptr; 2288bcb0991SDimitry Andric if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(), 2298bcb0991SDimitry Andric m_Zero()))) 2308bcb0991SDimitry Andric return false; 2318bcb0991SDimitry Andric 2328bcb0991SDimitry Andric // Insert the limit into a vector. 2338bcb0991SDimitry Andric Instruction *BECount = nullptr; 2348bcb0991SDimitry Andric if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount), 2358bcb0991SDimitry Andric m_Zero()))) 2368bcb0991SDimitry Andric return false; 2378bcb0991SDimitry Andric 2388bcb0991SDimitry Andric // The limit calculation, backedge count. 2398bcb0991SDimitry Andric Value *TripCount = nullptr; 2408bcb0991SDimitry Andric if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes()))) 2418bcb0991SDimitry Andric return false; 2428bcb0991SDimitry Andric 243*480093f4SDimitry Andric if (TripCount != NumElements || !L->isLoopInvariant(BECount)) 2448bcb0991SDimitry Andric return false; 2458bcb0991SDimitry Andric 2468bcb0991SDimitry Andric // Now back to searching inside the loop body... 2478bcb0991SDimitry Andric // Find the add with takes the index iv and adds a constant vector to it. 2488bcb0991SDimitry Andric Instruction *BroadcastSplat = nullptr; 2498bcb0991SDimitry Andric Constant *Const = nullptr; 2508bcb0991SDimitry Andric if (!match(Induction, m_Add(m_Instruction(BroadcastSplat), 2518bcb0991SDimitry Andric m_Constant(Const)))) 2528bcb0991SDimitry Andric return false; 2538bcb0991SDimitry Andric 2548bcb0991SDimitry Andric // Check that we're adding <0, 1, 2, 3... 2558bcb0991SDimitry Andric if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) { 2568bcb0991SDimitry Andric for (unsigned i = 0; i < CDS->getNumElements(); ++i) { 2578bcb0991SDimitry Andric if (CDS->getElementAsInteger(i) != i) 2588bcb0991SDimitry Andric return false; 2598bcb0991SDimitry Andric } 2608bcb0991SDimitry Andric } else 2618bcb0991SDimitry Andric return false; 2628bcb0991SDimitry Andric 2638bcb0991SDimitry Andric // The shuffle which broadcasts the index iv into a vector. 2648bcb0991SDimitry Andric if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(), 2658bcb0991SDimitry Andric m_Zero()))) 2668bcb0991SDimitry Andric return false; 2678bcb0991SDimitry Andric 2688bcb0991SDimitry Andric // The insert element which initialises a vector with the index iv. 2698bcb0991SDimitry Andric Instruction *IV = nullptr; 2708bcb0991SDimitry Andric if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero()))) 2718bcb0991SDimitry Andric return false; 2728bcb0991SDimitry Andric 2738bcb0991SDimitry Andric // The index iv. 2748bcb0991SDimitry Andric auto *Phi = dyn_cast<PHINode>(IV); 2758bcb0991SDimitry Andric if (!Phi) 2768bcb0991SDimitry Andric return false; 2778bcb0991SDimitry Andric 2788bcb0991SDimitry Andric // TODO: Don't think we need to check the entry value. 2798bcb0991SDimitry Andric Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader()); 2808bcb0991SDimitry Andric if (!match(OnEntry, m_Zero())) 2818bcb0991SDimitry Andric return false; 2828bcb0991SDimitry Andric 2838bcb0991SDimitry Andric Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch()); 2848bcb0991SDimitry Andric unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements(); 2858bcb0991SDimitry Andric 2868bcb0991SDimitry Andric Instruction *LHS = nullptr; 2878bcb0991SDimitry Andric if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes)))) 2888bcb0991SDimitry Andric return false; 2898bcb0991SDimitry Andric 2908bcb0991SDimitry Andric return LHS == Phi; 2918bcb0991SDimitry Andric } 2928bcb0991SDimitry Andric 2938bcb0991SDimitry Andric static VectorType* getVectorType(IntrinsicInst *I) { 2948bcb0991SDimitry Andric unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1; 2958bcb0991SDimitry Andric auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType()); 2968bcb0991SDimitry Andric return cast<VectorType>(PtrTy->getElementType()); 2978bcb0991SDimitry Andric } 2988bcb0991SDimitry Andric 2998bcb0991SDimitry Andric bool MVETailPredication::IsPredicatedVectorLoop() { 3008bcb0991SDimitry Andric // Check that the loop contains at least one masked load/store intrinsic. 3018bcb0991SDimitry Andric // We only support 'normal' vector instructions - other than masked 3028bcb0991SDimitry Andric // load/stores. 3038bcb0991SDimitry Andric for (auto *BB : L->getBlocks()) { 3048bcb0991SDimitry Andric for (auto &I : *BB) { 3058bcb0991SDimitry Andric if (IsMasked(&I)) { 3068bcb0991SDimitry Andric VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I)); 3078bcb0991SDimitry Andric unsigned Lanes = VecTy->getNumElements(); 3088bcb0991SDimitry Andric unsigned ElementWidth = VecTy->getScalarSizeInBits(); 3098bcb0991SDimitry Andric // MVE vectors are 128-bit, but don't support 128 x i1. 3108bcb0991SDimitry Andric // TODO: Can we support vectors larger than 128-bits? 3118bcb0991SDimitry Andric unsigned MaxWidth = TTI->getRegisterBitWidth(true); 312*480093f4SDimitry Andric if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth) 3138bcb0991SDimitry Andric return false; 3148bcb0991SDimitry Andric MaskedInsts.push_back(cast<IntrinsicInst>(&I)); 3158bcb0991SDimitry Andric } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) { 3168bcb0991SDimitry Andric for (auto &U : Int->args()) { 3178bcb0991SDimitry Andric if (isa<VectorType>(U->getType())) 3188bcb0991SDimitry Andric return false; 3198bcb0991SDimitry Andric } 3208bcb0991SDimitry Andric } 3218bcb0991SDimitry Andric } 3228bcb0991SDimitry Andric } 3238bcb0991SDimitry Andric 3248bcb0991SDimitry Andric return !MaskedInsts.empty(); 3258bcb0991SDimitry Andric } 3268bcb0991SDimitry Andric 3278bcb0991SDimitry Andric Value* MVETailPredication::ComputeElements(Value *TripCount, 3288bcb0991SDimitry Andric VectorType *VecTy) { 3298bcb0991SDimitry Andric const SCEV *TripCountSE = SE->getSCEV(TripCount); 3308bcb0991SDimitry Andric ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()), 3318bcb0991SDimitry Andric VecTy->getNumElements()); 3328bcb0991SDimitry Andric 3338bcb0991SDimitry Andric if (VF->equalsInt(1)) 3348bcb0991SDimitry Andric return nullptr; 3358bcb0991SDimitry Andric 3368bcb0991SDimitry Andric // TODO: Support constant trip counts. 3378bcb0991SDimitry Andric auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* { 3388bcb0991SDimitry Andric if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 3398bcb0991SDimitry Andric if (Const->getAPInt() != -VF->getValue()) 3408bcb0991SDimitry Andric return nullptr; 3418bcb0991SDimitry Andric } else 3428bcb0991SDimitry Andric return nullptr; 3438bcb0991SDimitry Andric return dyn_cast<SCEVMulExpr>(S->getOperand(1)); 3448bcb0991SDimitry Andric }; 3458bcb0991SDimitry Andric 3468bcb0991SDimitry Andric auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* { 3478bcb0991SDimitry Andric if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 3488bcb0991SDimitry Andric if (Const->getValue() != VF) 3498bcb0991SDimitry Andric return nullptr; 3508bcb0991SDimitry Andric } else 3518bcb0991SDimitry Andric return nullptr; 3528bcb0991SDimitry Andric return dyn_cast<SCEVUDivExpr>(S->getOperand(1)); 3538bcb0991SDimitry Andric }; 3548bcb0991SDimitry Andric 3558bcb0991SDimitry Andric auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* { 3568bcb0991SDimitry Andric if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) { 3578bcb0991SDimitry Andric if (Const->getValue() != VF) 3588bcb0991SDimitry Andric return nullptr; 3598bcb0991SDimitry Andric } else 3608bcb0991SDimitry Andric return nullptr; 3618bcb0991SDimitry Andric 3628bcb0991SDimitry Andric if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) { 3638bcb0991SDimitry Andric if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) { 3648bcb0991SDimitry Andric if (Const->getAPInt() != (VF->getValue() - 1)) 3658bcb0991SDimitry Andric return nullptr; 3668bcb0991SDimitry Andric } else 3678bcb0991SDimitry Andric return nullptr; 3688bcb0991SDimitry Andric 3698bcb0991SDimitry Andric return RoundUp->getOperand(1); 3708bcb0991SDimitry Andric } 3718bcb0991SDimitry Andric return nullptr; 3728bcb0991SDimitry Andric }; 3738bcb0991SDimitry Andric 3748bcb0991SDimitry Andric // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to 3758bcb0991SDimitry Andric // determine the numbers of elements instead? Looks like this is what is used 3768bcb0991SDimitry Andric // for delinearization, but I'm not sure if it can be applied to the 3778bcb0991SDimitry Andric // vectorized form - at least not without a bit more work than I feel 3788bcb0991SDimitry Andric // comfortable with. 3798bcb0991SDimitry Andric 3808bcb0991SDimitry Andric // Search for Elems in the following SCEV: 3818bcb0991SDimitry Andric // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw> 3828bcb0991SDimitry Andric const SCEV *Elems = nullptr; 3838bcb0991SDimitry Andric if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE)) 3848bcb0991SDimitry Andric if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1))) 3858bcb0991SDimitry Andric if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS())) 3868bcb0991SDimitry Andric if (auto *Mul = VisitAdd(Add)) 3878bcb0991SDimitry Andric if (auto *Div = VisitMul(Mul)) 3888bcb0991SDimitry Andric if (auto *Res = VisitDiv(Div)) 3898bcb0991SDimitry Andric Elems = Res; 3908bcb0991SDimitry Andric 3918bcb0991SDimitry Andric if (!Elems) 3928bcb0991SDimitry Andric return nullptr; 3938bcb0991SDimitry Andric 3948bcb0991SDimitry Andric Instruction *InsertPt = L->getLoopPreheader()->getTerminator(); 3958bcb0991SDimitry Andric if (!isSafeToExpandAt(Elems, InsertPt, *SE)) 3968bcb0991SDimitry Andric return nullptr; 3978bcb0991SDimitry Andric 3988bcb0991SDimitry Andric auto DL = L->getHeader()->getModule()->getDataLayout(); 3998bcb0991SDimitry Andric SCEVExpander Expander(*SE, DL, "elements"); 4008bcb0991SDimitry Andric return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt); 4018bcb0991SDimitry Andric } 4028bcb0991SDimitry Andric 4038bcb0991SDimitry Andric // Look through the exit block to see whether there's a duplicate predicate 4048bcb0991SDimitry Andric // instruction. This can happen when we need to perform a select on values 4058bcb0991SDimitry Andric // from the last and previous iteration. Instead of doing a straight 4068bcb0991SDimitry Andric // replacement of that predicate with the vctp, clone the vctp and place it 4078bcb0991SDimitry Andric // in the block. This means that the VPR doesn't have to be live into the 4088bcb0991SDimitry Andric // exit block which should make it easier to convert this loop into a proper 4098bcb0991SDimitry Andric // tail predicated loop. 4108bcb0991SDimitry Andric static void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates, 4118bcb0991SDimitry Andric SetVector<Instruction*> &MaybeDead, Loop *L) { 412*480093f4SDimitry Andric BasicBlock *Exit = L->getUniqueExitBlock(); 413*480093f4SDimitry Andric if (!Exit) { 414*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n"); 415*480093f4SDimitry Andric return; 416*480093f4SDimitry Andric } 417*480093f4SDimitry Andric 4188bcb0991SDimitry Andric for (auto &Pair : NewPredicates) { 4198bcb0991SDimitry Andric Instruction *OldPred = Pair.first; 4208bcb0991SDimitry Andric Instruction *NewPred = Pair.second; 4218bcb0991SDimitry Andric 4228bcb0991SDimitry Andric for (auto &I : *Exit) { 4238bcb0991SDimitry Andric if (I.isSameOperationAs(OldPred)) { 4248bcb0991SDimitry Andric Instruction *PredClone = NewPred->clone(); 4258bcb0991SDimitry Andric PredClone->insertBefore(&I); 4268bcb0991SDimitry Andric I.replaceAllUsesWith(PredClone); 4278bcb0991SDimitry Andric MaybeDead.insert(&I); 428*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: replacing: "; I.dump(); 429*480093f4SDimitry Andric dbgs() << "ARM TP: with: "; PredClone->dump()); 4308bcb0991SDimitry Andric break; 4318bcb0991SDimitry Andric } 4328bcb0991SDimitry Andric } 4338bcb0991SDimitry Andric } 4348bcb0991SDimitry Andric 4358bcb0991SDimitry Andric // Drop references and add operands to check for dead. 4368bcb0991SDimitry Andric SmallPtrSet<Instruction*, 4> Dead; 4378bcb0991SDimitry Andric while (!MaybeDead.empty()) { 4388bcb0991SDimitry Andric auto *I = MaybeDead.front(); 4398bcb0991SDimitry Andric MaybeDead.remove(I); 4408bcb0991SDimitry Andric if (I->hasNUsesOrMore(1)) 4418bcb0991SDimitry Andric continue; 4428bcb0991SDimitry Andric 4438bcb0991SDimitry Andric for (auto &U : I->operands()) { 4448bcb0991SDimitry Andric if (auto *OpI = dyn_cast<Instruction>(U)) 4458bcb0991SDimitry Andric MaybeDead.insert(OpI); 4468bcb0991SDimitry Andric } 4478bcb0991SDimitry Andric I->dropAllReferences(); 4488bcb0991SDimitry Andric Dead.insert(I); 4498bcb0991SDimitry Andric } 4508bcb0991SDimitry Andric 451*480093f4SDimitry Andric for (auto *I : Dead) { 452*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump()); 4538bcb0991SDimitry Andric I->eraseFromParent(); 454*480093f4SDimitry Andric } 4558bcb0991SDimitry Andric 4568bcb0991SDimitry Andric for (auto I : L->blocks()) 4578bcb0991SDimitry Andric DeleteDeadPHIs(I); 4588bcb0991SDimitry Andric } 4598bcb0991SDimitry Andric 460*480093f4SDimitry Andric void MVETailPredication::InsertVCTPIntrinsic(Instruction *Predicate, 461*480093f4SDimitry Andric DenseMap<Instruction*, Instruction*> &NewPredicates, 462*480093f4SDimitry Andric VectorType *VecTy, Value *NumElements) { 463*480093f4SDimitry Andric IRBuilder<> Builder(L->getHeader()->getFirstNonPHI()); 464*480093f4SDimitry Andric Module *M = L->getHeader()->getModule(); 465*480093f4SDimitry Andric Type *Ty = IntegerType::get(M->getContext(), 32); 4668bcb0991SDimitry Andric 467*480093f4SDimitry Andric // Insert a phi to count the number of elements processed by the loop. 468*480093f4SDimitry Andric PHINode *Processed = Builder.CreatePHI(Ty, 2); 469*480093f4SDimitry Andric Processed->addIncoming(NumElements, L->getLoopPreheader()); 470*480093f4SDimitry Andric 471*480093f4SDimitry Andric // Insert the intrinsic to represent the effect of tail predication. 472*480093f4SDimitry Andric Builder.SetInsertPoint(cast<Instruction>(Predicate)); 473*480093f4SDimitry Andric ConstantInt *Factor = 474*480093f4SDimitry Andric ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements()); 475*480093f4SDimitry Andric 476*480093f4SDimitry Andric Intrinsic::ID VCTPID; 477*480093f4SDimitry Andric switch (VecTy->getNumElements()) { 478*480093f4SDimitry Andric default: 479*480093f4SDimitry Andric llvm_unreachable("unexpected number of lanes"); 480*480093f4SDimitry Andric case 4: VCTPID = Intrinsic::arm_mve_vctp32; break; 481*480093f4SDimitry Andric case 8: VCTPID = Intrinsic::arm_mve_vctp16; break; 482*480093f4SDimitry Andric case 16: VCTPID = Intrinsic::arm_mve_vctp8; break; 483*480093f4SDimitry Andric 484*480093f4SDimitry Andric // FIXME: vctp64 currently not supported because the predicate 485*480093f4SDimitry Andric // vector wants to be <2 x i1>, but v2i1 is not a legal MVE 486*480093f4SDimitry Andric // type, so problems happen at isel time. 487*480093f4SDimitry Andric // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics 488*480093f4SDimitry Andric // purposes, but takes a v4i1 instead of a v2i1. 489*480093f4SDimitry Andric } 490*480093f4SDimitry Andric Function *VCTP = Intrinsic::getDeclaration(M, VCTPID); 491*480093f4SDimitry Andric Value *TailPredicate = Builder.CreateCall(VCTP, Processed); 492*480093f4SDimitry Andric Predicate->replaceAllUsesWith(TailPredicate); 493*480093f4SDimitry Andric NewPredicates[Predicate] = cast<Instruction>(TailPredicate); 494*480093f4SDimitry Andric 495*480093f4SDimitry Andric // Add the incoming value to the new phi. 496*480093f4SDimitry Andric // TODO: This add likely already exists in the loop. 497*480093f4SDimitry Andric Value *Remaining = Builder.CreateSub(Processed, Factor); 498*480093f4SDimitry Andric Processed->addIncoming(Remaining, L->getLoopLatch()); 499*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: " 500*480093f4SDimitry Andric << *Processed << "\n" 501*480093f4SDimitry Andric << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n"); 502*480093f4SDimitry Andric } 503*480093f4SDimitry Andric 504*480093f4SDimitry Andric bool MVETailPredication::TryConvert(Value *TripCount) { 505*480093f4SDimitry Andric if (!IsPredicatedVectorLoop()) { 506*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop"); 507*480093f4SDimitry Andric return false; 508*480093f4SDimitry Andric } 509*480093f4SDimitry Andric 510*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n"); 5118bcb0991SDimitry Andric 5128bcb0991SDimitry Andric // Walk through the masked intrinsics and try to find whether the predicate 5138bcb0991SDimitry Andric // operand is generated from an induction variable. 5148bcb0991SDimitry Andric SetVector<Instruction*> Predicates; 5158bcb0991SDimitry Andric DenseMap<Instruction*, Instruction*> NewPredicates; 5168bcb0991SDimitry Andric 5178bcb0991SDimitry Andric for (auto *I : MaskedInsts) { 5188bcb0991SDimitry Andric Intrinsic::ID ID = I->getIntrinsicID(); 5198bcb0991SDimitry Andric unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3; 5208bcb0991SDimitry Andric auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp)); 5218bcb0991SDimitry Andric if (!Predicate || Predicates.count(Predicate)) 5228bcb0991SDimitry Andric continue; 5238bcb0991SDimitry Andric 5248bcb0991SDimitry Andric VectorType *VecTy = getVectorType(I); 5258bcb0991SDimitry Andric Value *NumElements = ComputeElements(TripCount, VecTy); 5268bcb0991SDimitry Andric if (!NumElements) 5278bcb0991SDimitry Andric continue; 5288bcb0991SDimitry Andric 5298bcb0991SDimitry Andric if (!isTailPredicate(Predicate, NumElements)) { 530*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate << "\n"); 5318bcb0991SDimitry Andric continue; 5328bcb0991SDimitry Andric } 5338bcb0991SDimitry Andric 534*480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate << "\n"); 5358bcb0991SDimitry Andric Predicates.insert(Predicate); 5368bcb0991SDimitry Andric 537*480093f4SDimitry Andric InsertVCTPIntrinsic(Predicate, NewPredicates, VecTy, NumElements); 5388bcb0991SDimitry Andric } 5398bcb0991SDimitry Andric 5408bcb0991SDimitry Andric // Now clean up. 5418bcb0991SDimitry Andric Cleanup(NewPredicates, Predicates, L); 5428bcb0991SDimitry Andric return true; 5438bcb0991SDimitry Andric } 5448bcb0991SDimitry Andric 5458bcb0991SDimitry Andric Pass *llvm::createMVETailPredicationPass() { 5468bcb0991SDimitry Andric return new MVETailPredication(); 5478bcb0991SDimitry Andric } 5488bcb0991SDimitry Andric 5498bcb0991SDimitry Andric char MVETailPredication::ID = 0; 5508bcb0991SDimitry Andric 5518bcb0991SDimitry Andric INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false) 5528bcb0991SDimitry Andric INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false) 553