//===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
/// branches to help accelerate DSP applications. These two extensions can be
/// combined to provide implicit vector predication within a low-overhead loop.
/// The HardwareLoops pass inserts intrinsics identifying loops that the
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
/// responsible for generating a vectorized loop in which the lanes are
/// predicated upon the iteration counter. This pass looks at these predicated
/// vector loops, that are targets for low-overhead loops, and prepares them for
/// code generation. Once the vectorizer has produced a masked loop, there are a
/// couple of final forms:
/// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCTP instructions, predicating multiple VPT
///   blocks of instructions operating on different vector types.
23*8bcb0991SDimitry Andric 24*8bcb0991SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 25*8bcb0991SDimitry Andric #include "llvm/Analysis/LoopPass.h" 26*8bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h" 27*8bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpander.h" 28*8bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h" 29*8bcb0991SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 30*8bcb0991SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 31*8bcb0991SDimitry Andric #include "llvm/IR/Instructions.h" 32*8bcb0991SDimitry Andric #include "llvm/IR/IRBuilder.h" 33*8bcb0991SDimitry Andric #include "llvm/IR/PatternMatch.h" 34*8bcb0991SDimitry Andric #include "llvm/Support/Debug.h" 35*8bcb0991SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 36*8bcb0991SDimitry Andric #include "ARM.h" 37*8bcb0991SDimitry Andric #include "ARMSubtarget.h" 38*8bcb0991SDimitry Andric 39*8bcb0991SDimitry Andric using namespace llvm; 40*8bcb0991SDimitry Andric 41*8bcb0991SDimitry Andric #define DEBUG_TYPE "mve-tail-predication" 42*8bcb0991SDimitry Andric #define DESC "Transform predicated vector loops to use MVE tail predication" 43*8bcb0991SDimitry Andric 44*8bcb0991SDimitry Andric static cl::opt<bool> 45*8bcb0991SDimitry Andric DisableTailPredication("disable-mve-tail-predication", cl::Hidden, 46*8bcb0991SDimitry Andric cl::init(true), 47*8bcb0991SDimitry Andric cl::desc("Disable MVE Tail Predication")); 48*8bcb0991SDimitry Andric namespace { 49*8bcb0991SDimitry Andric 50*8bcb0991SDimitry Andric class MVETailPredication : public LoopPass { 51*8bcb0991SDimitry Andric SmallVector<IntrinsicInst*, 4> MaskedInsts; 52*8bcb0991SDimitry Andric Loop *L = nullptr; 53*8bcb0991SDimitry Andric ScalarEvolution *SE = nullptr; 54*8bcb0991SDimitry Andric TargetTransformInfo *TTI = nullptr; 55*8bcb0991SDimitry Andric 56*8bcb0991SDimitry Andric public: 57*8bcb0991SDimitry Andric static char ID; 58*8bcb0991SDimitry 
Andric 59*8bcb0991SDimitry Andric MVETailPredication() : LoopPass(ID) { } 60*8bcb0991SDimitry Andric 61*8bcb0991SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 62*8bcb0991SDimitry Andric AU.addRequired<ScalarEvolutionWrapperPass>(); 63*8bcb0991SDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 64*8bcb0991SDimitry Andric AU.addRequired<TargetPassConfig>(); 65*8bcb0991SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 66*8bcb0991SDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 67*8bcb0991SDimitry Andric AU.setPreservesCFG(); 68*8bcb0991SDimitry Andric } 69*8bcb0991SDimitry Andric 70*8bcb0991SDimitry Andric bool runOnLoop(Loop *L, LPPassManager&) override; 71*8bcb0991SDimitry Andric 72*8bcb0991SDimitry Andric private: 73*8bcb0991SDimitry Andric 74*8bcb0991SDimitry Andric /// Perform the relevant checks on the loop and convert if possible. 75*8bcb0991SDimitry Andric bool TryConvert(Value *TripCount); 76*8bcb0991SDimitry Andric 77*8bcb0991SDimitry Andric /// Return whether this is a vectorized loop, that contains masked 78*8bcb0991SDimitry Andric /// load/stores. 79*8bcb0991SDimitry Andric bool IsPredicatedVectorLoop(); 80*8bcb0991SDimitry Andric 81*8bcb0991SDimitry Andric /// Compute a value for the total number of elements that the predicated 82*8bcb0991SDimitry Andric /// loop will process. 83*8bcb0991SDimitry Andric Value *ComputeElements(Value *TripCount, VectorType *VecTy); 84*8bcb0991SDimitry Andric 85*8bcb0991SDimitry Andric /// Is the icmp that generates an i1 vector, based upon a loop counter 86*8bcb0991SDimitry Andric /// and a limit that is defined outside the loop. 
87*8bcb0991SDimitry Andric bool isTailPredicate(Instruction *Predicate, Value *NumElements); 88*8bcb0991SDimitry Andric }; 89*8bcb0991SDimitry Andric 90*8bcb0991SDimitry Andric } // end namespace 91*8bcb0991SDimitry Andric 92*8bcb0991SDimitry Andric static bool IsDecrement(Instruction &I) { 93*8bcb0991SDimitry Andric auto *Call = dyn_cast<IntrinsicInst>(&I); 94*8bcb0991SDimitry Andric if (!Call) 95*8bcb0991SDimitry Andric return false; 96*8bcb0991SDimitry Andric 97*8bcb0991SDimitry Andric Intrinsic::ID ID = Call->getIntrinsicID(); 98*8bcb0991SDimitry Andric return ID == Intrinsic::loop_decrement_reg; 99*8bcb0991SDimitry Andric } 100*8bcb0991SDimitry Andric 101*8bcb0991SDimitry Andric static bool IsMasked(Instruction *I) { 102*8bcb0991SDimitry Andric auto *Call = dyn_cast<IntrinsicInst>(I); 103*8bcb0991SDimitry Andric if (!Call) 104*8bcb0991SDimitry Andric return false; 105*8bcb0991SDimitry Andric 106*8bcb0991SDimitry Andric Intrinsic::ID ID = Call->getIntrinsicID(); 107*8bcb0991SDimitry Andric // TODO: Support gather/scatter expand/compress operations. 
108*8bcb0991SDimitry Andric return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load; 109*8bcb0991SDimitry Andric } 110*8bcb0991SDimitry Andric 111*8bcb0991SDimitry Andric bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { 112*8bcb0991SDimitry Andric if (skipLoop(L) || DisableTailPredication) 113*8bcb0991SDimitry Andric return false; 114*8bcb0991SDimitry Andric 115*8bcb0991SDimitry Andric Function &F = *L->getHeader()->getParent(); 116*8bcb0991SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 117*8bcb0991SDimitry Andric auto &TM = TPC.getTM<TargetMachine>(); 118*8bcb0991SDimitry Andric auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 119*8bcb0991SDimitry Andric TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 120*8bcb0991SDimitry Andric SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 121*8bcb0991SDimitry Andric this->L = L; 122*8bcb0991SDimitry Andric 123*8bcb0991SDimitry Andric // The MVE and LOB extensions are combined to enable tail-predication, but 124*8bcb0991SDimitry Andric // there's nothing preventing us from generating VCTP instructions for v8.1m. 
125*8bcb0991SDimitry Andric if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { 126*8bcb0991SDimitry Andric LLVM_DEBUG(dbgs() << "TP: Not a v8.1m.main+mve target.\n"); 127*8bcb0991SDimitry Andric return false; 128*8bcb0991SDimitry Andric } 129*8bcb0991SDimitry Andric 130*8bcb0991SDimitry Andric BasicBlock *Preheader = L->getLoopPreheader(); 131*8bcb0991SDimitry Andric if (!Preheader) 132*8bcb0991SDimitry Andric return false; 133*8bcb0991SDimitry Andric 134*8bcb0991SDimitry Andric auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { 135*8bcb0991SDimitry Andric for (auto &I : *BB) { 136*8bcb0991SDimitry Andric auto *Call = dyn_cast<IntrinsicInst>(&I); 137*8bcb0991SDimitry Andric if (!Call) 138*8bcb0991SDimitry Andric continue; 139*8bcb0991SDimitry Andric 140*8bcb0991SDimitry Andric Intrinsic::ID ID = Call->getIntrinsicID(); 141*8bcb0991SDimitry Andric if (ID == Intrinsic::set_loop_iterations || 142*8bcb0991SDimitry Andric ID == Intrinsic::test_set_loop_iterations) 143*8bcb0991SDimitry Andric return cast<IntrinsicInst>(&I); 144*8bcb0991SDimitry Andric } 145*8bcb0991SDimitry Andric return nullptr; 146*8bcb0991SDimitry Andric }; 147*8bcb0991SDimitry Andric 148*8bcb0991SDimitry Andric // Look for the hardware loop intrinsic that sets the iteration count. 149*8bcb0991SDimitry Andric IntrinsicInst *Setup = FindLoopIterations(Preheader); 150*8bcb0991SDimitry Andric 151*8bcb0991SDimitry Andric // The test.set iteration could live in the pre- preheader. 152*8bcb0991SDimitry Andric if (!Setup) { 153*8bcb0991SDimitry Andric if (!Preheader->getSinglePredecessor()) 154*8bcb0991SDimitry Andric return false; 155*8bcb0991SDimitry Andric Setup = FindLoopIterations(Preheader->getSinglePredecessor()); 156*8bcb0991SDimitry Andric if (!Setup) 157*8bcb0991SDimitry Andric return false; 158*8bcb0991SDimitry Andric } 159*8bcb0991SDimitry Andric 160*8bcb0991SDimitry Andric // Search for the hardware loop intrinic that decrements the loop counter. 
161*8bcb0991SDimitry Andric IntrinsicInst *Decrement = nullptr; 162*8bcb0991SDimitry Andric for (auto *BB : L->getBlocks()) { 163*8bcb0991SDimitry Andric for (auto &I : *BB) { 164*8bcb0991SDimitry Andric if (IsDecrement(I)) { 165*8bcb0991SDimitry Andric Decrement = cast<IntrinsicInst>(&I); 166*8bcb0991SDimitry Andric break; 167*8bcb0991SDimitry Andric } 168*8bcb0991SDimitry Andric } 169*8bcb0991SDimitry Andric } 170*8bcb0991SDimitry Andric 171*8bcb0991SDimitry Andric if (!Decrement) 172*8bcb0991SDimitry Andric return false; 173*8bcb0991SDimitry Andric 174*8bcb0991SDimitry Andric LLVM_DEBUG(dbgs() << "TP: Running on Loop: " << *L 175*8bcb0991SDimitry Andric << *Setup << "\n" 176*8bcb0991SDimitry Andric << *Decrement << "\n"); 177*8bcb0991SDimitry Andric bool Changed = TryConvert(Setup->getArgOperand(0)); 178*8bcb0991SDimitry Andric return Changed; 179*8bcb0991SDimitry Andric } 180*8bcb0991SDimitry Andric 181*8bcb0991SDimitry Andric bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) { 182*8bcb0991SDimitry Andric // Look for the following: 183*8bcb0991SDimitry Andric 184*8bcb0991SDimitry Andric // %trip.count.minus.1 = add i32 %N, -1 185*8bcb0991SDimitry Andric // %broadcast.splatinsert10 = insertelement <4 x i32> undef, 186*8bcb0991SDimitry Andric // i32 %trip.count.minus.1, i32 0 187*8bcb0991SDimitry Andric // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10, 188*8bcb0991SDimitry Andric // <4 x i32> undef, 189*8bcb0991SDimitry Andric // <4 x i32> zeroinitializer 190*8bcb0991SDimitry Andric // ... 191*8bcb0991SDimitry Andric // ... 
192*8bcb0991SDimitry Andric // %index = phi i32 193*8bcb0991SDimitry Andric // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 194*8bcb0991SDimitry Andric // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, 195*8bcb0991SDimitry Andric // <4 x i32> undef, 196*8bcb0991SDimitry Andric // <4 x i32> zeroinitializer 197*8bcb0991SDimitry Andric // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3> 198*8bcb0991SDimitry Andric // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11 199*8bcb0991SDimitry Andric 200*8bcb0991SDimitry Andric // And return whether V == %pred. 201*8bcb0991SDimitry Andric 202*8bcb0991SDimitry Andric using namespace PatternMatch; 203*8bcb0991SDimitry Andric 204*8bcb0991SDimitry Andric CmpInst::Predicate Pred; 205*8bcb0991SDimitry Andric Instruction *Shuffle = nullptr; 206*8bcb0991SDimitry Andric Instruction *Induction = nullptr; 207*8bcb0991SDimitry Andric 208*8bcb0991SDimitry Andric // The vector icmp 209*8bcb0991SDimitry Andric if (!match(I, m_ICmp(Pred, m_Instruction(Induction), 210*8bcb0991SDimitry Andric m_Instruction(Shuffle))) || 211*8bcb0991SDimitry Andric Pred != ICmpInst::ICMP_ULE || !L->isLoopInvariant(Shuffle)) 212*8bcb0991SDimitry Andric return false; 213*8bcb0991SDimitry Andric 214*8bcb0991SDimitry Andric // First find the stuff outside the loop which is setting up the limit 215*8bcb0991SDimitry Andric // vector.... 216*8bcb0991SDimitry Andric // The invariant shuffle that broadcast the limit into a vector. 217*8bcb0991SDimitry Andric Instruction *Insert = nullptr; 218*8bcb0991SDimitry Andric if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(), 219*8bcb0991SDimitry Andric m_Zero()))) 220*8bcb0991SDimitry Andric return false; 221*8bcb0991SDimitry Andric 222*8bcb0991SDimitry Andric // Insert the limit into a vector. 
223*8bcb0991SDimitry Andric Instruction *BECount = nullptr; 224*8bcb0991SDimitry Andric if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount), 225*8bcb0991SDimitry Andric m_Zero()))) 226*8bcb0991SDimitry Andric return false; 227*8bcb0991SDimitry Andric 228*8bcb0991SDimitry Andric // The limit calculation, backedge count. 229*8bcb0991SDimitry Andric Value *TripCount = nullptr; 230*8bcb0991SDimitry Andric if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes()))) 231*8bcb0991SDimitry Andric return false; 232*8bcb0991SDimitry Andric 233*8bcb0991SDimitry Andric if (TripCount != NumElements) 234*8bcb0991SDimitry Andric return false; 235*8bcb0991SDimitry Andric 236*8bcb0991SDimitry Andric // Now back to searching inside the loop body... 237*8bcb0991SDimitry Andric // Find the add with takes the index iv and adds a constant vector to it. 238*8bcb0991SDimitry Andric Instruction *BroadcastSplat = nullptr; 239*8bcb0991SDimitry Andric Constant *Const = nullptr; 240*8bcb0991SDimitry Andric if (!match(Induction, m_Add(m_Instruction(BroadcastSplat), 241*8bcb0991SDimitry Andric m_Constant(Const)))) 242*8bcb0991SDimitry Andric return false; 243*8bcb0991SDimitry Andric 244*8bcb0991SDimitry Andric // Check that we're adding <0, 1, 2, 3... 245*8bcb0991SDimitry Andric if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) { 246*8bcb0991SDimitry Andric for (unsigned i = 0; i < CDS->getNumElements(); ++i) { 247*8bcb0991SDimitry Andric if (CDS->getElementAsInteger(i) != i) 248*8bcb0991SDimitry Andric return false; 249*8bcb0991SDimitry Andric } 250*8bcb0991SDimitry Andric } else 251*8bcb0991SDimitry Andric return false; 252*8bcb0991SDimitry Andric 253*8bcb0991SDimitry Andric // The shuffle which broadcasts the index iv into a vector. 
254*8bcb0991SDimitry Andric if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(), 255*8bcb0991SDimitry Andric m_Zero()))) 256*8bcb0991SDimitry Andric return false; 257*8bcb0991SDimitry Andric 258*8bcb0991SDimitry Andric // The insert element which initialises a vector with the index iv. 259*8bcb0991SDimitry Andric Instruction *IV = nullptr; 260*8bcb0991SDimitry Andric if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero()))) 261*8bcb0991SDimitry Andric return false; 262*8bcb0991SDimitry Andric 263*8bcb0991SDimitry Andric // The index iv. 264*8bcb0991SDimitry Andric auto *Phi = dyn_cast<PHINode>(IV); 265*8bcb0991SDimitry Andric if (!Phi) 266*8bcb0991SDimitry Andric return false; 267*8bcb0991SDimitry Andric 268*8bcb0991SDimitry Andric // TODO: Don't think we need to check the entry value. 269*8bcb0991SDimitry Andric Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader()); 270*8bcb0991SDimitry Andric if (!match(OnEntry, m_Zero())) 271*8bcb0991SDimitry Andric return false; 272*8bcb0991SDimitry Andric 273*8bcb0991SDimitry Andric Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch()); 274*8bcb0991SDimitry Andric unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements(); 275*8bcb0991SDimitry Andric 276*8bcb0991SDimitry Andric Instruction *LHS = nullptr; 277*8bcb0991SDimitry Andric if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes)))) 278*8bcb0991SDimitry Andric return false; 279*8bcb0991SDimitry Andric 280*8bcb0991SDimitry Andric return LHS == Phi; 281*8bcb0991SDimitry Andric } 282*8bcb0991SDimitry Andric 283*8bcb0991SDimitry Andric static VectorType* getVectorType(IntrinsicInst *I) { 284*8bcb0991SDimitry Andric unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 
0 : 1; 285*8bcb0991SDimitry Andric auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType()); 286*8bcb0991SDimitry Andric return cast<VectorType>(PtrTy->getElementType()); 287*8bcb0991SDimitry Andric } 288*8bcb0991SDimitry Andric 289*8bcb0991SDimitry Andric bool MVETailPredication::IsPredicatedVectorLoop() { 290*8bcb0991SDimitry Andric // Check that the loop contains at least one masked load/store intrinsic. 291*8bcb0991SDimitry Andric // We only support 'normal' vector instructions - other than masked 292*8bcb0991SDimitry Andric // load/stores. 293*8bcb0991SDimitry Andric for (auto *BB : L->getBlocks()) { 294*8bcb0991SDimitry Andric for (auto &I : *BB) { 295*8bcb0991SDimitry Andric if (IsMasked(&I)) { 296*8bcb0991SDimitry Andric VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I)); 297*8bcb0991SDimitry Andric unsigned Lanes = VecTy->getNumElements(); 298*8bcb0991SDimitry Andric unsigned ElementWidth = VecTy->getScalarSizeInBits(); 299*8bcb0991SDimitry Andric // MVE vectors are 128-bit, but don't support 128 x i1. 300*8bcb0991SDimitry Andric // TODO: Can we support vectors larger than 128-bits? 
301*8bcb0991SDimitry Andric unsigned MaxWidth = TTI->getRegisterBitWidth(true); 302*8bcb0991SDimitry Andric if (Lanes * ElementWidth != MaxWidth || Lanes == MaxWidth) 303*8bcb0991SDimitry Andric return false; 304*8bcb0991SDimitry Andric MaskedInsts.push_back(cast<IntrinsicInst>(&I)); 305*8bcb0991SDimitry Andric } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) { 306*8bcb0991SDimitry Andric for (auto &U : Int->args()) { 307*8bcb0991SDimitry Andric if (isa<VectorType>(U->getType())) 308*8bcb0991SDimitry Andric return false; 309*8bcb0991SDimitry Andric } 310*8bcb0991SDimitry Andric } 311*8bcb0991SDimitry Andric } 312*8bcb0991SDimitry Andric } 313*8bcb0991SDimitry Andric 314*8bcb0991SDimitry Andric return !MaskedInsts.empty(); 315*8bcb0991SDimitry Andric } 316*8bcb0991SDimitry Andric 317*8bcb0991SDimitry Andric Value* MVETailPredication::ComputeElements(Value *TripCount, 318*8bcb0991SDimitry Andric VectorType *VecTy) { 319*8bcb0991SDimitry Andric const SCEV *TripCountSE = SE->getSCEV(TripCount); 320*8bcb0991SDimitry Andric ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()), 321*8bcb0991SDimitry Andric VecTy->getNumElements()); 322*8bcb0991SDimitry Andric 323*8bcb0991SDimitry Andric if (VF->equalsInt(1)) 324*8bcb0991SDimitry Andric return nullptr; 325*8bcb0991SDimitry Andric 326*8bcb0991SDimitry Andric // TODO: Support constant trip counts. 
327*8bcb0991SDimitry Andric auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* { 328*8bcb0991SDimitry Andric if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 329*8bcb0991SDimitry Andric if (Const->getAPInt() != -VF->getValue()) 330*8bcb0991SDimitry Andric return nullptr; 331*8bcb0991SDimitry Andric } else 332*8bcb0991SDimitry Andric return nullptr; 333*8bcb0991SDimitry Andric return dyn_cast<SCEVMulExpr>(S->getOperand(1)); 334*8bcb0991SDimitry Andric }; 335*8bcb0991SDimitry Andric 336*8bcb0991SDimitry Andric auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* { 337*8bcb0991SDimitry Andric if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 338*8bcb0991SDimitry Andric if (Const->getValue() != VF) 339*8bcb0991SDimitry Andric return nullptr; 340*8bcb0991SDimitry Andric } else 341*8bcb0991SDimitry Andric return nullptr; 342*8bcb0991SDimitry Andric return dyn_cast<SCEVUDivExpr>(S->getOperand(1)); 343*8bcb0991SDimitry Andric }; 344*8bcb0991SDimitry Andric 345*8bcb0991SDimitry Andric auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* { 346*8bcb0991SDimitry Andric if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) { 347*8bcb0991SDimitry Andric if (Const->getValue() != VF) 348*8bcb0991SDimitry Andric return nullptr; 349*8bcb0991SDimitry Andric } else 350*8bcb0991SDimitry Andric return nullptr; 351*8bcb0991SDimitry Andric 352*8bcb0991SDimitry Andric if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) { 353*8bcb0991SDimitry Andric if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) { 354*8bcb0991SDimitry Andric if (Const->getAPInt() != (VF->getValue() - 1)) 355*8bcb0991SDimitry Andric return nullptr; 356*8bcb0991SDimitry Andric } else 357*8bcb0991SDimitry Andric return nullptr; 358*8bcb0991SDimitry Andric 359*8bcb0991SDimitry Andric return RoundUp->getOperand(1); 360*8bcb0991SDimitry Andric } 361*8bcb0991SDimitry Andric return nullptr; 362*8bcb0991SDimitry Andric }; 363*8bcb0991SDimitry Andric 
364*8bcb0991SDimitry Andric // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to 365*8bcb0991SDimitry Andric // determine the numbers of elements instead? Looks like this is what is used 366*8bcb0991SDimitry Andric // for delinearization, but I'm not sure if it can be applied to the 367*8bcb0991SDimitry Andric // vectorized form - at least not without a bit more work than I feel 368*8bcb0991SDimitry Andric // comfortable with. 369*8bcb0991SDimitry Andric 370*8bcb0991SDimitry Andric // Search for Elems in the following SCEV: 371*8bcb0991SDimitry Andric // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw> 372*8bcb0991SDimitry Andric const SCEV *Elems = nullptr; 373*8bcb0991SDimitry Andric if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE)) 374*8bcb0991SDimitry Andric if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1))) 375*8bcb0991SDimitry Andric if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS())) 376*8bcb0991SDimitry Andric if (auto *Mul = VisitAdd(Add)) 377*8bcb0991SDimitry Andric if (auto *Div = VisitMul(Mul)) 378*8bcb0991SDimitry Andric if (auto *Res = VisitDiv(Div)) 379*8bcb0991SDimitry Andric Elems = Res; 380*8bcb0991SDimitry Andric 381*8bcb0991SDimitry Andric if (!Elems) 382*8bcb0991SDimitry Andric return nullptr; 383*8bcb0991SDimitry Andric 384*8bcb0991SDimitry Andric Instruction *InsertPt = L->getLoopPreheader()->getTerminator(); 385*8bcb0991SDimitry Andric if (!isSafeToExpandAt(Elems, InsertPt, *SE)) 386*8bcb0991SDimitry Andric return nullptr; 387*8bcb0991SDimitry Andric 388*8bcb0991SDimitry Andric auto DL = L->getHeader()->getModule()->getDataLayout(); 389*8bcb0991SDimitry Andric SCEVExpander Expander(*SE, DL, "elements"); 390*8bcb0991SDimitry Andric return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt); 391*8bcb0991SDimitry Andric } 392*8bcb0991SDimitry Andric 393*8bcb0991SDimitry Andric // Look through the exit block to see whether there's a duplicate predicate 394*8bcb0991SDimitry Andric 
// instruction. This can happen when we need to perform a select on values 395*8bcb0991SDimitry Andric // from the last and previous iteration. Instead of doing a straight 396*8bcb0991SDimitry Andric // replacement of that predicate with the vctp, clone the vctp and place it 397*8bcb0991SDimitry Andric // in the block. This means that the VPR doesn't have to be live into the 398*8bcb0991SDimitry Andric // exit block which should make it easier to convert this loop into a proper 399*8bcb0991SDimitry Andric // tail predicated loop. 400*8bcb0991SDimitry Andric static void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates, 401*8bcb0991SDimitry Andric SetVector<Instruction*> &MaybeDead, Loop *L) { 402*8bcb0991SDimitry Andric if (BasicBlock *Exit = L->getUniqueExitBlock()) { 403*8bcb0991SDimitry Andric for (auto &Pair : NewPredicates) { 404*8bcb0991SDimitry Andric Instruction *OldPred = Pair.first; 405*8bcb0991SDimitry Andric Instruction *NewPred = Pair.second; 406*8bcb0991SDimitry Andric 407*8bcb0991SDimitry Andric for (auto &I : *Exit) { 408*8bcb0991SDimitry Andric if (I.isSameOperationAs(OldPred)) { 409*8bcb0991SDimitry Andric Instruction *PredClone = NewPred->clone(); 410*8bcb0991SDimitry Andric PredClone->insertBefore(&I); 411*8bcb0991SDimitry Andric I.replaceAllUsesWith(PredClone); 412*8bcb0991SDimitry Andric MaybeDead.insert(&I); 413*8bcb0991SDimitry Andric break; 414*8bcb0991SDimitry Andric } 415*8bcb0991SDimitry Andric } 416*8bcb0991SDimitry Andric } 417*8bcb0991SDimitry Andric } 418*8bcb0991SDimitry Andric 419*8bcb0991SDimitry Andric // Drop references and add operands to check for dead. 
420*8bcb0991SDimitry Andric SmallPtrSet<Instruction*, 4> Dead; 421*8bcb0991SDimitry Andric while (!MaybeDead.empty()) { 422*8bcb0991SDimitry Andric auto *I = MaybeDead.front(); 423*8bcb0991SDimitry Andric MaybeDead.remove(I); 424*8bcb0991SDimitry Andric if (I->hasNUsesOrMore(1)) 425*8bcb0991SDimitry Andric continue; 426*8bcb0991SDimitry Andric 427*8bcb0991SDimitry Andric for (auto &U : I->operands()) { 428*8bcb0991SDimitry Andric if (auto *OpI = dyn_cast<Instruction>(U)) 429*8bcb0991SDimitry Andric MaybeDead.insert(OpI); 430*8bcb0991SDimitry Andric } 431*8bcb0991SDimitry Andric I->dropAllReferences(); 432*8bcb0991SDimitry Andric Dead.insert(I); 433*8bcb0991SDimitry Andric } 434*8bcb0991SDimitry Andric 435*8bcb0991SDimitry Andric for (auto *I : Dead) 436*8bcb0991SDimitry Andric I->eraseFromParent(); 437*8bcb0991SDimitry Andric 438*8bcb0991SDimitry Andric for (auto I : L->blocks()) 439*8bcb0991SDimitry Andric DeleteDeadPHIs(I); 440*8bcb0991SDimitry Andric } 441*8bcb0991SDimitry Andric 442*8bcb0991SDimitry Andric bool MVETailPredication::TryConvert(Value *TripCount) { 443*8bcb0991SDimitry Andric if (!IsPredicatedVectorLoop()) 444*8bcb0991SDimitry Andric return false; 445*8bcb0991SDimitry Andric 446*8bcb0991SDimitry Andric LLVM_DEBUG(dbgs() << "TP: Found predicated vector loop.\n"); 447*8bcb0991SDimitry Andric 448*8bcb0991SDimitry Andric // Walk through the masked intrinsics and try to find whether the predicate 449*8bcb0991SDimitry Andric // operand is generated from an induction variable. 
450*8bcb0991SDimitry Andric Module *M = L->getHeader()->getModule(); 451*8bcb0991SDimitry Andric Type *Ty = IntegerType::get(M->getContext(), 32); 452*8bcb0991SDimitry Andric SetVector<Instruction*> Predicates; 453*8bcb0991SDimitry Andric DenseMap<Instruction*, Instruction*> NewPredicates; 454*8bcb0991SDimitry Andric 455*8bcb0991SDimitry Andric for (auto *I : MaskedInsts) { 456*8bcb0991SDimitry Andric Intrinsic::ID ID = I->getIntrinsicID(); 457*8bcb0991SDimitry Andric unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3; 458*8bcb0991SDimitry Andric auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp)); 459*8bcb0991SDimitry Andric if (!Predicate || Predicates.count(Predicate)) 460*8bcb0991SDimitry Andric continue; 461*8bcb0991SDimitry Andric 462*8bcb0991SDimitry Andric VectorType *VecTy = getVectorType(I); 463*8bcb0991SDimitry Andric Value *NumElements = ComputeElements(TripCount, VecTy); 464*8bcb0991SDimitry Andric if (!NumElements) 465*8bcb0991SDimitry Andric continue; 466*8bcb0991SDimitry Andric 467*8bcb0991SDimitry Andric if (!isTailPredicate(Predicate, NumElements)) { 468*8bcb0991SDimitry Andric LLVM_DEBUG(dbgs() << "TP: Not tail predicate: " << *Predicate << "\n"); 469*8bcb0991SDimitry Andric continue; 470*8bcb0991SDimitry Andric } 471*8bcb0991SDimitry Andric 472*8bcb0991SDimitry Andric LLVM_DEBUG(dbgs() << "TP: Found tail predicate: " << *Predicate << "\n"); 473*8bcb0991SDimitry Andric Predicates.insert(Predicate); 474*8bcb0991SDimitry Andric 475*8bcb0991SDimitry Andric // Insert a phi to count the number of elements processed by the loop. 476*8bcb0991SDimitry Andric IRBuilder<> Builder(L->getHeader()->getFirstNonPHI()); 477*8bcb0991SDimitry Andric PHINode *Processed = Builder.CreatePHI(Ty, 2); 478*8bcb0991SDimitry Andric Processed->addIncoming(NumElements, L->getLoopPreheader()); 479*8bcb0991SDimitry Andric 480*8bcb0991SDimitry Andric // Insert the intrinsic to represent the effect of tail predication. 
481*8bcb0991SDimitry Andric Builder.SetInsertPoint(cast<Instruction>(Predicate)); 482*8bcb0991SDimitry Andric ConstantInt *Factor = 483*8bcb0991SDimitry Andric ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements()); 484*8bcb0991SDimitry Andric Intrinsic::ID VCTPID; 485*8bcb0991SDimitry Andric switch (VecTy->getNumElements()) { 486*8bcb0991SDimitry Andric default: 487*8bcb0991SDimitry Andric llvm_unreachable("unexpected number of lanes"); 488*8bcb0991SDimitry Andric case 2: VCTPID = Intrinsic::arm_vctp64; break; 489*8bcb0991SDimitry Andric case 4: VCTPID = Intrinsic::arm_vctp32; break; 490*8bcb0991SDimitry Andric case 8: VCTPID = Intrinsic::arm_vctp16; break; 491*8bcb0991SDimitry Andric case 16: VCTPID = Intrinsic::arm_vctp8; break; 492*8bcb0991SDimitry Andric } 493*8bcb0991SDimitry Andric Function *VCTP = Intrinsic::getDeclaration(M, VCTPID); 494*8bcb0991SDimitry Andric Value *TailPredicate = Builder.CreateCall(VCTP, Processed); 495*8bcb0991SDimitry Andric Predicate->replaceAllUsesWith(TailPredicate); 496*8bcb0991SDimitry Andric NewPredicates[Predicate] = cast<Instruction>(TailPredicate); 497*8bcb0991SDimitry Andric 498*8bcb0991SDimitry Andric // Add the incoming value to the new phi. 499*8bcb0991SDimitry Andric // TODO: This add likely already exists in the loop. 500*8bcb0991SDimitry Andric Value *Remaining = Builder.CreateSub(Processed, Factor); 501*8bcb0991SDimitry Andric Processed->addIncoming(Remaining, L->getLoopLatch()); 502*8bcb0991SDimitry Andric LLVM_DEBUG(dbgs() << "TP: Insert processed elements phi: " 503*8bcb0991SDimitry Andric << *Processed << "\n" 504*8bcb0991SDimitry Andric << "TP: Inserted VCTP: " << *TailPredicate << "\n"); 505*8bcb0991SDimitry Andric } 506*8bcb0991SDimitry Andric 507*8bcb0991SDimitry Andric // Now clean up. 
508*8bcb0991SDimitry Andric Cleanup(NewPredicates, Predicates, L); 509*8bcb0991SDimitry Andric return true; 510*8bcb0991SDimitry Andric } 511*8bcb0991SDimitry Andric 512*8bcb0991SDimitry Andric Pass *llvm::createMVETailPredicationPass() { 513*8bcb0991SDimitry Andric return new MVETailPredication(); 514*8bcb0991SDimitry Andric } 515*8bcb0991SDimitry Andric 516*8bcb0991SDimitry Andric char MVETailPredication::ID = 0; 517*8bcb0991SDimitry Andric 518*8bcb0991SDimitry Andric INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false) 519*8bcb0991SDimitry Andric INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false) 520