15ffd83dbSDimitry Andric //===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===// 28bcb0991SDimitry Andric // 38bcb0991SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48bcb0991SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 58bcb0991SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68bcb0991SDimitry Andric // 78bcb0991SDimitry Andric //===----------------------------------------------------------------------===// 88bcb0991SDimitry Andric // 98bcb0991SDimitry Andric /// \file 108bcb0991SDimitry Andric /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead 115ffd83dbSDimitry Andric /// branches to help accelerate DSP applications. These two extensions, 125ffd83dbSDimitry Andric /// combined with a new form of predication called tail-predication, can be used 135ffd83dbSDimitry Andric /// to provide implicit vector predication within a low-overhead loop. 145ffd83dbSDimitry Andric /// This is implicit because the predicate of active/inactive lanes is 155ffd83dbSDimitry Andric /// calculated by hardware, and thus does not need to be explicitly passed 165ffd83dbSDimitry Andric /// to vector instructions. The instructions responsible for this are the 175ffd83dbSDimitry Andric /// DLSTP and WLSTP instructions, which setup a tail-predicated loop and the 185ffd83dbSDimitry Andric /// the total number of data elements processed by the loop. The loop-end 195ffd83dbSDimitry Andric /// LETP instruction is responsible for decrementing and setting the remaining 205ffd83dbSDimitry Andric /// elements to be processed and generating the mask of active lanes. 215ffd83dbSDimitry Andric /// 228bcb0991SDimitry Andric /// The HardwareLoops pass inserts intrinsics identifying loops that the 238bcb0991SDimitry Andric /// backend will attempt to convert into a low-overhead loop. The vectorizer is 248bcb0991SDimitry Andric /// responsible for generating a vectorized loop in which the lanes are 25e8d8bef9SDimitry Andric /// predicated upon an get.active.lane.mask intrinsic. This pass looks at these 26e8d8bef9SDimitry Andric /// get.active.lane.mask intrinsic and attempts to convert them to VCTP 27e8d8bef9SDimitry Andric /// instructions. This will be picked up by the ARM Low-overhead loop pass later 28e8d8bef9SDimitry Andric /// in the backend, which performs the final transformation to a DLSTP or WLSTP 29e8d8bef9SDimitry Andric /// tail-predicated loop. 30e8d8bef9SDimitry Andric // 31e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===// 328bcb0991SDimitry Andric 33480093f4SDimitry Andric #include "ARM.h" 34480093f4SDimitry Andric #include "ARMSubtarget.h" 355ffd83dbSDimitry Andric #include "ARMTargetTransformInfo.h" 368bcb0991SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 378bcb0991SDimitry Andric #include "llvm/Analysis/LoopPass.h" 388bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h" 398bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h" 405ffd83dbSDimitry Andric #include "llvm/Analysis/TargetLibraryInfo.h" 418bcb0991SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 4206c3fb27SDimitry Andric #include "llvm/Analysis/ValueTracking.h" 438bcb0991SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 448bcb0991SDimitry Andric #include "llvm/IR/IRBuilder.h" 45480093f4SDimitry Andric #include "llvm/IR/Instructions.h" 46480093f4SDimitry Andric #include "llvm/IR/IntrinsicsARM.h" 478bcb0991SDimitry Andric #include "llvm/IR/PatternMatch.h" 485ffd83dbSDimitry Andric #include "llvm/InitializePasses.h" 498bcb0991SDimitry Andric #include "llvm/Support/Debug.h" 508bcb0991SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 51e8d8bef9SDimitry Andric #include "llvm/Transforms/Utils/Local.h" 525ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h" 535ffd83dbSDimitry Andric #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 548bcb0991SDimitry Andric 558bcb0991SDimitry Andric using namespace llvm; 568bcb0991SDimitry Andric 578bcb0991SDimitry Andric #define DEBUG_TYPE "mve-tail-predication" 588bcb0991SDimitry Andric #define DESC "Transform predicated vector loops to use MVE tail predication" 598bcb0991SDimitry Andric 605ffd83dbSDimitry Andric cl::opt<TailPredication::Mode> EnableTailPredication( 61e8d8bef9SDimitry Andric "tail-predication", cl::desc("MVE tail-predication pass options"), 62e8d8bef9SDimitry Andric cl::init(TailPredication::Enabled), 635ffd83dbSDimitry Andric cl::values(clEnumValN(TailPredication::Disabled, "disabled", 645ffd83dbSDimitry Andric "Don't tail-predicate loops"), 655ffd83dbSDimitry Andric clEnumValN(TailPredication::EnabledNoReductions, 665ffd83dbSDimitry Andric "enabled-no-reductions", 675ffd83dbSDimitry Andric "Enable tail-predication, but not for reduction loops"), 685ffd83dbSDimitry Andric clEnumValN(TailPredication::Enabled, 695ffd83dbSDimitry Andric "enabled", 705ffd83dbSDimitry Andric "Enable tail-predication, including reduction loops"), 715ffd83dbSDimitry Andric clEnumValN(TailPredication::ForceEnabledNoReductions, 725ffd83dbSDimitry Andric "force-enabled-no-reductions", 735ffd83dbSDimitry Andric "Enable tail-predication, but not for reduction loops, " 745ffd83dbSDimitry Andric "and force this which might be unsafe"), 755ffd83dbSDimitry Andric clEnumValN(TailPredication::ForceEnabled, 765ffd83dbSDimitry Andric "force-enabled", 775ffd83dbSDimitry Andric "Enable tail-predication, including reduction loops, " 785ffd83dbSDimitry Andric "and force this which might be unsafe"))); 795ffd83dbSDimitry Andric 805ffd83dbSDimitry Andric 818bcb0991SDimitry Andric namespace { 828bcb0991SDimitry Andric 838bcb0991SDimitry Andric class MVETailPredication : public LoopPass { 848bcb0991SDimitry Andric SmallVector<IntrinsicInst*, 4> MaskedInsts; 858bcb0991SDimitry Andric Loop *L = nullptr; 868bcb0991SDimitry Andric ScalarEvolution *SE = nullptr; 878bcb0991SDimitry Andric TargetTransformInfo *TTI = nullptr; 885ffd83dbSDimitry Andric const ARMSubtarget *ST = nullptr; 898bcb0991SDimitry Andric 908bcb0991SDimitry Andric public: 918bcb0991SDimitry Andric static char ID; 928bcb0991SDimitry Andric 938bcb0991SDimitry Andric MVETailPredication() : LoopPass(ID) { } 948bcb0991SDimitry Andric 958bcb0991SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 968bcb0991SDimitry Andric AU.addRequired<ScalarEvolutionWrapperPass>(); 978bcb0991SDimitry Andric AU.addRequired<LoopInfoWrapperPass>(); 988bcb0991SDimitry Andric AU.addRequired<TargetPassConfig>(); 998bcb0991SDimitry Andric AU.addRequired<TargetTransformInfoWrapperPass>(); 1008bcb0991SDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 1018bcb0991SDimitry Andric AU.setPreservesCFG(); 1028bcb0991SDimitry Andric } 1038bcb0991SDimitry Andric 1048bcb0991SDimitry Andric bool runOnLoop(Loop *L, LPPassManager&) override; 1058bcb0991SDimitry Andric 1068bcb0991SDimitry Andric private: 107e8d8bef9SDimitry Andric /// Perform the relevant checks on the loop and convert active lane masks if 108e8d8bef9SDimitry Andric /// possible. 109e8d8bef9SDimitry Andric bool TryConvertActiveLaneMask(Value *TripCount); 1108bcb0991SDimitry Andric 111e8d8bef9SDimitry Andric /// Perform several checks on the arguments of @llvm.get.active.lane.mask 112e8d8bef9SDimitry Andric /// intrinsic. E.g., check that the loop induction variable and the element 113e8d8bef9SDimitry Andric /// count are of the form we expect, and also perform overflow checks for 114e8d8bef9SDimitry Andric /// the new expressions that are created. 11506c3fb27SDimitry Andric const SCEV *IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount); 116480093f4SDimitry Andric 117480093f4SDimitry Andric /// Insert the intrinsic to represent the effect of tail predication. 11806c3fb27SDimitry Andric void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *Start); 1198bcb0991SDimitry Andric }; 1208bcb0991SDimitry Andric 1218bcb0991SDimitry Andric } // end namespace 1228bcb0991SDimitry Andric 1238bcb0991SDimitry Andric bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { 1245ffd83dbSDimitry Andric if (skipLoop(L) || !EnableTailPredication) 1258bcb0991SDimitry Andric return false; 1268bcb0991SDimitry Andric 1275ffd83dbSDimitry Andric MaskedInsts.clear(); 1288bcb0991SDimitry Andric Function &F = *L->getHeader()->getParent(); 1298bcb0991SDimitry Andric auto &TPC = getAnalysis<TargetPassConfig>(); 1308bcb0991SDimitry Andric auto &TM = TPC.getTM<TargetMachine>(); 1315ffd83dbSDimitry Andric ST = &TM.getSubtarget<ARMSubtarget>(F); 1328bcb0991SDimitry Andric TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 1338bcb0991SDimitry Andric SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 1348bcb0991SDimitry Andric this->L = L; 1358bcb0991SDimitry Andric 1368bcb0991SDimitry Andric // The MVE and LOB extensions are combined to enable tail-predication, but 1378bcb0991SDimitry Andric // there's nothing preventing us from generating VCTP instructions for v8.1m. 1388bcb0991SDimitry Andric if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { 139480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n"); 1408bcb0991SDimitry Andric return false; 1418bcb0991SDimitry Andric } 1428bcb0991SDimitry Andric 1438bcb0991SDimitry Andric BasicBlock *Preheader = L->getLoopPreheader(); 1448bcb0991SDimitry Andric if (!Preheader) 1458bcb0991SDimitry Andric return false; 1468bcb0991SDimitry Andric 1478bcb0991SDimitry Andric auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { 1488bcb0991SDimitry Andric for (auto &I : *BB) { 1498bcb0991SDimitry Andric auto *Call = dyn_cast<IntrinsicInst>(&I); 1508bcb0991SDimitry Andric if (!Call) 1518bcb0991SDimitry Andric continue; 1528bcb0991SDimitry Andric 1538bcb0991SDimitry Andric Intrinsic::ID ID = Call->getIntrinsicID(); 154e8d8bef9SDimitry Andric if (ID == Intrinsic::start_loop_iterations || 155fe6060f1SDimitry Andric ID == Intrinsic::test_start_loop_iterations) 1568bcb0991SDimitry Andric return cast<IntrinsicInst>(&I); 1578bcb0991SDimitry Andric } 1588bcb0991SDimitry Andric return nullptr; 1598bcb0991SDimitry Andric }; 1608bcb0991SDimitry Andric 1618bcb0991SDimitry Andric // Look for the hardware loop intrinsic that sets the iteration count. 1628bcb0991SDimitry Andric IntrinsicInst *Setup = FindLoopIterations(Preheader); 1638bcb0991SDimitry Andric 1648bcb0991SDimitry Andric // The test.set iteration could live in the pre-preheader. 1658bcb0991SDimitry Andric if (!Setup) { 1668bcb0991SDimitry Andric if (!Preheader->getSinglePredecessor()) 1678bcb0991SDimitry Andric return false; 1688bcb0991SDimitry Andric Setup = FindLoopIterations(Preheader->getSinglePredecessor()); 1698bcb0991SDimitry Andric if (!Setup) 1708bcb0991SDimitry Andric return false; 1718bcb0991SDimitry Andric } 1728bcb0991SDimitry Andric 173e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"); 1748bcb0991SDimitry Andric 175e8d8bef9SDimitry Andric bool Changed = TryConvertActiveLaneMask(Setup->getArgOperand(0)); 1768bcb0991SDimitry Andric 177e8d8bef9SDimitry Andric return Changed; 1788bcb0991SDimitry Andric } 1798bcb0991SDimitry Andric 1805ffd83dbSDimitry Andric // The active lane intrinsic has this form: 1815ffd83dbSDimitry Andric // 182e8d8bef9SDimitry Andric // @llvm.get.active.lane.mask(IV, TC) 1835ffd83dbSDimitry Andric // 1845ffd83dbSDimitry Andric // Here we perform checks that this intrinsic behaves as expected, 1855ffd83dbSDimitry Andric // which means: 1865ffd83dbSDimitry Andric // 187e8d8bef9SDimitry Andric // 1) Check that the TripCount (TC) belongs to this loop (originally). 188e8d8bef9SDimitry Andric // 2) The element count (TC) needs to be sufficiently large that the decrement 189e8d8bef9SDimitry Andric // of element counter doesn't overflow, which means that we need to prove: 1905ffd83dbSDimitry Andric // ceil(ElementCount / VectorWidth) >= TripCount 1915ffd83dbSDimitry Andric // by rounding up ElementCount up: 1925ffd83dbSDimitry Andric // ((ElementCount + (VectorWidth - 1)) / VectorWidth 1935ffd83dbSDimitry Andric // and evaluate if expression isKnownNonNegative: 1945ffd83dbSDimitry Andric // (((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount 1955ffd83dbSDimitry Andric // 3) The IV must be an induction phi with an increment equal to the 1965ffd83dbSDimitry Andric // vector width. 19706c3fb27SDimitry Andric const SCEV *MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, 198e8d8bef9SDimitry Andric Value *TripCount) { 1995ffd83dbSDimitry Andric bool ForceTailPredication = 2005ffd83dbSDimitry Andric EnableTailPredication == TailPredication::ForceEnabledNoReductions || 2015ffd83dbSDimitry Andric EnableTailPredication == TailPredication::ForceEnabled; 2025ffd83dbSDimitry Andric 203e8d8bef9SDimitry Andric Value *ElemCount = ActiveLaneMask->getOperand(1); 2044652422eSDimitry Andric bool Changed = false; 2054652422eSDimitry Andric if (!L->makeLoopInvariant(ElemCount, Changed)) 20606c3fb27SDimitry Andric return nullptr; 2074652422eSDimitry Andric 208e8d8bef9SDimitry Andric auto *EC= SE->getSCEV(ElemCount); 2095ffd83dbSDimitry Andric auto *TC = SE->getSCEV(TripCount); 210e8d8bef9SDimitry Andric int VectorWidth = 211e8d8bef9SDimitry Andric cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements(); 2120eae32dcSDimitry Andric if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 && 2130eae32dcSDimitry Andric VectorWidth != 16) 21406c3fb27SDimitry Andric return nullptr; 215e8d8bef9SDimitry Andric ConstantInt *ConstElemCount = nullptr; 2165ffd83dbSDimitry Andric 217e8d8bef9SDimitry Andric // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to 218e8d8bef9SDimitry Andric // this loop. The scalar tripcount corresponds the number of elements 219e8d8bef9SDimitry Andric // processed by the loop, so we will refer to that from this point on. 220e8d8bef9SDimitry Andric if (!SE->isLoopInvariant(EC, L)) { 221e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n"); 22206c3fb27SDimitry Andric return nullptr; 22306c3fb27SDimitry Andric } 22406c3fb27SDimitry Andric 22506c3fb27SDimitry Andric // 2) Find out if IV is an induction phi. Note that we can't use Loop 22606c3fb27SDimitry Andric // helpers here to get the induction variable, because the hardware loop is 22706c3fb27SDimitry Andric // no longer in loopsimplify form, and also the hwloop intrinsic uses a 22806c3fb27SDimitry Andric // different counter. Using SCEV, we check that the induction is of the 22906c3fb27SDimitry Andric // form i = i + 4, where the increment must be equal to the VectorWidth. 23006c3fb27SDimitry Andric auto *IV = ActiveLaneMask->getOperand(0); 23106c3fb27SDimitry Andric auto *IVExpr = SE->getSCEV(IV); 23206c3fb27SDimitry Andric auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr); 23306c3fb27SDimitry Andric 23406c3fb27SDimitry Andric if (!AddExpr) { 23506c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump()); 23606c3fb27SDimitry Andric return nullptr; 23706c3fb27SDimitry Andric } 23806c3fb27SDimitry Andric // Check that this AddRec is associated with this loop. 23906c3fb27SDimitry Andric if (AddExpr->getLoop() != L) { 24006c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n"); 24106c3fb27SDimitry Andric return nullptr; 24206c3fb27SDimitry Andric } 24306c3fb27SDimitry Andric auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1)); 24406c3fb27SDimitry Andric if (!Step) { 24506c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: "; 24606c3fb27SDimitry Andric AddExpr->getOperand(1)->dump()); 24706c3fb27SDimitry Andric return nullptr; 24806c3fb27SDimitry Andric } 24906c3fb27SDimitry Andric auto StepValue = Step->getValue()->getSExtValue(); 25006c3fb27SDimitry Andric if (VectorWidth != StepValue) { 25106c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue 25206c3fb27SDimitry Andric << " doesn't match vector width " << VectorWidth << "\n"); 25306c3fb27SDimitry Andric return nullptr; 2545ffd83dbSDimitry Andric } 2555ffd83dbSDimitry Andric 256e8d8bef9SDimitry Andric if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) { 257e8d8bef9SDimitry Andric ConstantInt *TC = dyn_cast<ConstantInt>(TripCount); 258e8d8bef9SDimitry Andric if (!TC) { 259e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in " 260e8d8bef9SDimitry Andric "set.loop.iterations\n"); 26106c3fb27SDimitry Andric return nullptr; 262e8d8bef9SDimitry Andric } 263e8d8bef9SDimitry Andric 264e8d8bef9SDimitry Andric // Calculate 2 tripcount values and check that they are consistent with 265e8d8bef9SDimitry Andric // each other. The TripCount for a predicated vector loop body is 266e8d8bef9SDimitry Andric // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we 267e8d8bef9SDimitry Andric // work it out here. 268e8d8bef9SDimitry Andric uint64_t TC1 = TC->getZExtValue(); 269e8d8bef9SDimitry Andric uint64_t TC2 = 270e8d8bef9SDimitry Andric (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth; 271e8d8bef9SDimitry Andric 272e8d8bef9SDimitry Andric // If the tripcount values are inconsistent, we can't insert the VCTP and 273e8d8bef9SDimitry Andric // trigger tail-predication; keep the intrinsic as a get.active.lane.mask 274e8d8bef9SDimitry Andric // and legalize this. 275e8d8bef9SDimitry Andric if (TC1 != TC2) { 276e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: " 277e8d8bef9SDimitry Andric << TC1 << " from set.loop.iterations, and " 278e8d8bef9SDimitry Andric << TC2 << " from get.active.lane.mask\n"); 27906c3fb27SDimitry Andric return nullptr; 280e8d8bef9SDimitry Andric } 281e8d8bef9SDimitry Andric } else if (!ForceTailPredication) { 28206c3fb27SDimitry Andric // 3) We need to prove that the sub expression that we create in the 283e8d8bef9SDimitry Andric // tail-predicated loop body, which calculates the remaining elements to be 284e8d8bef9SDimitry Andric // processed, is non-negative, i.e. it doesn't overflow: 2855ffd83dbSDimitry Andric // 286e8d8bef9SDimitry Andric // ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0 2875ffd83dbSDimitry Andric // 288e8d8bef9SDimitry Andric // This is true if: 2895ffd83dbSDimitry Andric // 290e8d8bef9SDimitry Andric // TripCount == (ElementCount + VectorWidth - 1) / VectorWidth 2915ffd83dbSDimitry Andric // 292e8d8bef9SDimitry Andric // which what we will be using here. 2935ffd83dbSDimitry Andric // 294e8d8bef9SDimitry Andric auto *VW = SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth)); 295e8d8bef9SDimitry Andric // ElementCount + (VW-1): 29606c3fb27SDimitry Andric auto *Start = AddExpr->getStart(); 297e8d8bef9SDimitry Andric auto *ECPlusVWMinus1 = SE->getAddExpr(EC, 2985ffd83dbSDimitry Andric SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1))); 2995ffd83dbSDimitry Andric 300e8d8bef9SDimitry Andric // Ceil = ElementCount + (VW-1) / VW 301e8d8bef9SDimitry Andric auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW); 302e8d8bef9SDimitry Andric 303e8d8bef9SDimitry Andric // Prevent unused variable warnings with TC 304e8d8bef9SDimitry Andric (void)TC; 30506c3fb27SDimitry Andric LLVM_DEBUG({ 306e8d8bef9SDimitry Andric dbgs() << "ARM TP: Analysing overflow behaviour for:\n"; 30706c3fb27SDimitry Andric dbgs() << "ARM TP: - TripCount = " << *TC << "\n"; 30806c3fb27SDimitry Andric dbgs() << "ARM TP: - ElemCount = " << *EC << "\n"; 30906c3fb27SDimitry Andric dbgs() << "ARM TP: - Start = " << *Start << "\n"; 31006c3fb27SDimitry Andric dbgs() << "ARM TP: - BETC = " << *SE->getBackedgeTakenCount(L) << "\n"; 311e8d8bef9SDimitry Andric dbgs() << "ARM TP: - VecWidth = " << VectorWidth << "\n"; 31206c3fb27SDimitry Andric dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil << "\n"; 31306c3fb27SDimitry Andric }); 314e8d8bef9SDimitry Andric 315e8d8bef9SDimitry Andric // As an example, almost all the tripcount expressions (produced by the 316e8d8bef9SDimitry Andric // vectoriser) look like this: 317e8d8bef9SDimitry Andric // 31806c3fb27SDimitry Andric // TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw> - start) /u 4) 319e8d8bef9SDimitry Andric // 320e8d8bef9SDimitry Andric // and "ElementCount + (VW-1) / VW": 321e8d8bef9SDimitry Andric // 322e8d8bef9SDimitry Andric // Ceil = ((3 + %N) /u 4) 323e8d8bef9SDimitry Andric // 324e8d8bef9SDimitry Andric // Check for equality of TC and Ceil by calculating SCEV expression 325e8d8bef9SDimitry Andric // TC - Ceil and test it for zero. 326e8d8bef9SDimitry Andric // 32706c3fb27SDimitry Andric const SCEV *Div = SE->getUDivExpr( 32806c3fb27SDimitry Andric SE->getAddExpr(SE->getMulExpr(Ceil, VW), SE->getNegativeSCEV(VW), 32906c3fb27SDimitry Andric SE->getNegativeSCEV(Start)), 33006c3fb27SDimitry Andric VW); 33106c3fb27SDimitry Andric const SCEV *Sub = SE->getMinusSCEV(SE->getBackedgeTakenCount(L), Div); 33206c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: - Sub = "; Sub->dump()); 333e8d8bef9SDimitry Andric 334349cc55cSDimitry Andric // Use context sensitive facts about the path to the loop to refine. This 335349cc55cSDimitry Andric // comes up as the backedge taken count can incorporate context sensitive 336349cc55cSDimitry Andric // reasoning, and our RHS just above doesn't. 337349cc55cSDimitry Andric Sub = SE->applyLoopGuards(Sub, L); 33806c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: - (Guarded) = "; Sub->dump()); 339349cc55cSDimitry Andric 340349cc55cSDimitry Andric if (!Sub->isZero()) { 341e8d8bef9SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n"); 34206c3fb27SDimitry Andric return nullptr; 3435ffd83dbSDimitry Andric } 344e8d8bef9SDimitry Andric } 3455ffd83dbSDimitry Andric 34606c3fb27SDimitry Andric // Check that the start value is a multiple of the VectorWidth. 34706c3fb27SDimitry Andric // TODO: This could do with a method to check if the scev is a multiple of 34806c3fb27SDimitry Andric // VectorWidth. For the moment we just check for constants, muls and unknowns 34906c3fb27SDimitry Andric // (which use MaskedValueIsZero and seems to be the most common). 35006c3fb27SDimitry Andric if (auto *BaseC = dyn_cast<SCEVConstant>(AddExpr->getStart())) { 35106c3fb27SDimitry Andric if (BaseC->getAPInt().urem(VectorWidth) == 0) 35206c3fb27SDimitry Andric return SE->getMinusSCEV(EC, BaseC); 35306c3fb27SDimitry Andric } else if (auto *BaseV = dyn_cast<SCEVUnknown>(AddExpr->getStart())) { 35406c3fb27SDimitry Andric Type *Ty = BaseV->getType(); 35506c3fb27SDimitry Andric APInt Mask = APInt::getLowBitsSet(Ty->getPrimitiveSizeInBits(), 35606c3fb27SDimitry Andric Log2_64(VectorWidth)); 35706c3fb27SDimitry Andric if (MaskedValueIsZero(BaseV->getValue(), Mask, 358*0fca6ea1SDimitry Andric L->getHeader()->getDataLayout())) 35906c3fb27SDimitry Andric return SE->getMinusSCEV(EC, BaseV); 36006c3fb27SDimitry Andric } else if (auto *BaseMul = dyn_cast<SCEVMulExpr>(AddExpr->getStart())) { 36106c3fb27SDimitry Andric if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(0))) 36206c3fb27SDimitry Andric if (BaseC->getAPInt().urem(VectorWidth) == 0) 36306c3fb27SDimitry Andric return SE->getMinusSCEV(EC, BaseC); 36406c3fb27SDimitry Andric if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(1))) 36506c3fb27SDimitry Andric if (BaseC->getAPInt().urem(VectorWidth) == 0) 36606c3fb27SDimitry Andric return SE->getMinusSCEV(EC, BaseC); 36706c3fb27SDimitry Andric } 368e8d8bef9SDimitry Andric 36906c3fb27SDimitry Andric LLVM_DEBUG( 37006c3fb27SDimitry Andric dbgs() << "ARM TP: induction base is not know to be a multiple of VF: " 37106c3fb27SDimitry Andric << *AddExpr->getOperand(0) << "\n"); 37206c3fb27SDimitry Andric return nullptr; 3735ffd83dbSDimitry Andric } 3745ffd83dbSDimitry Andric 3755ffd83dbSDimitry Andric void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, 37606c3fb27SDimitry Andric Value *Start) { 3775ffd83dbSDimitry Andric IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); 378480093f4SDimitry Andric Module *M = L->getHeader()->getModule(); 379480093f4SDimitry Andric Type *Ty = IntegerType::get(M->getContext(), 32); 380e8d8bef9SDimitry Andric unsigned VectorWidth = 381e8d8bef9SDimitry Andric cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements(); 3828bcb0991SDimitry Andric 383480093f4SDimitry Andric // Insert a phi to count the number of elements processed by the loop. 3845f757f3fSDimitry Andric Builder.SetInsertPoint(L->getHeader(), L->getHeader()->getFirstNonPHIIt()); 385480093f4SDimitry Andric PHINode *Processed = Builder.CreatePHI(Ty, 2); 38606c3fb27SDimitry Andric Processed->addIncoming(Start, L->getLoopPreheader()); 387480093f4SDimitry Andric 388e8d8bef9SDimitry Andric // Replace @llvm.get.active.mask() with the ARM specific VCTP intrinic, and 389e8d8bef9SDimitry Andric // thus represent the effect of tail predication. 3905ffd83dbSDimitry Andric Builder.SetInsertPoint(ActiveLaneMask); 391e8d8bef9SDimitry Andric ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth); 392480093f4SDimitry Andric 393480093f4SDimitry Andric Intrinsic::ID VCTPID; 3945ffd83dbSDimitry Andric switch (VectorWidth) { 395480093f4SDimitry Andric default: 396480093f4SDimitry Andric llvm_unreachable("unexpected number of lanes"); 3970eae32dcSDimitry Andric case 2: VCTPID = Intrinsic::arm_mve_vctp64; break; 398480093f4SDimitry Andric case 4: VCTPID = Intrinsic::arm_mve_vctp32; break; 399480093f4SDimitry Andric case 8: VCTPID = Intrinsic::arm_mve_vctp16; break; 400480093f4SDimitry Andric case 16: VCTPID = Intrinsic::arm_mve_vctp8; break; 401480093f4SDimitry Andric } 402480093f4SDimitry Andric Function *VCTP = Intrinsic::getDeclaration(M, VCTPID); 4035ffd83dbSDimitry Andric Value *VCTPCall = Builder.CreateCall(VCTP, Processed); 4045ffd83dbSDimitry Andric ActiveLaneMask->replaceAllUsesWith(VCTPCall); 405480093f4SDimitry Andric 406480093f4SDimitry Andric // Add the incoming value to the new phi. 407480093f4SDimitry Andric // TODO: This add likely already exists in the loop. 408480093f4SDimitry Andric Value *Remaining = Builder.CreateSub(Processed, Factor); 409480093f4SDimitry Andric Processed->addIncoming(Remaining, L->getLoopLatch()); 410480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: " 411480093f4SDimitry Andric << *Processed << "\n" 4125ffd83dbSDimitry Andric << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n"); 413480093f4SDimitry Andric } 414480093f4SDimitry Andric 415e8d8bef9SDimitry Andric bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) { 416e8d8bef9SDimitry Andric SmallVector<IntrinsicInst *, 4> ActiveLaneMasks; 417e8d8bef9SDimitry Andric for (auto *BB : L->getBlocks()) 418e8d8bef9SDimitry Andric for (auto &I : *BB) 419e8d8bef9SDimitry Andric if (auto *Int = dyn_cast<IntrinsicInst>(&I)) 420e8d8bef9SDimitry Andric if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask) 421e8d8bef9SDimitry Andric ActiveLaneMasks.push_back(Int); 422e8d8bef9SDimitry Andric 423e8d8bef9SDimitry Andric if (ActiveLaneMasks.empty()) 424480093f4SDimitry Andric return false; 425480093f4SDimitry Andric 426480093f4SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n"); 4278bcb0991SDimitry Andric 428e8d8bef9SDimitry Andric for (auto *ActiveLaneMask : ActiveLaneMasks) { 4295ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: " 4305ffd83dbSDimitry Andric << *ActiveLaneMask << "\n"); 4318bcb0991SDimitry Andric 43206c3fb27SDimitry Andric const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount); 43306c3fb27SDimitry Andric if (!StartSCEV) { 4345ffd83dbSDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n"); 4355ffd83dbSDimitry Andric return false; 4365ffd83dbSDimitry Andric } 43706c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP. Start is " << *StartSCEV 43806c3fb27SDimitry Andric << "\n"); 439*0fca6ea1SDimitry Andric SCEVExpander Expander(*SE, L->getHeader()->getDataLayout(), 44006c3fb27SDimitry Andric "start"); 44106c3fb27SDimitry Andric Instruction *Ins = L->getLoopPreheader()->getTerminator(); 44206c3fb27SDimitry Andric Value *Start = Expander.expandCodeFor(StartSCEV, StartSCEV->getType(), Ins); 44306c3fb27SDimitry Andric LLVM_DEBUG(dbgs() << "ARM TP: Created start value " << *Start << "\n"); 44406c3fb27SDimitry Andric InsertVCTPIntrinsic(ActiveLaneMask, Start); 4458bcb0991SDimitry Andric } 4468bcb0991SDimitry Andric 447e8d8bef9SDimitry Andric // Remove dead instructions and now dead phis. 448e8d8bef9SDimitry Andric for (auto *II : ActiveLaneMasks) 449e8d8bef9SDimitry Andric RecursivelyDeleteTriviallyDeadInstructions(II); 450bdd1243dSDimitry Andric for (auto *I : L->blocks()) 451e8d8bef9SDimitry Andric DeleteDeadPHIs(I); 4528bcb0991SDimitry Andric return true; 4538bcb0991SDimitry Andric } 4548bcb0991SDimitry Andric 4558bcb0991SDimitry Andric Pass *llvm::createMVETailPredicationPass() { 4568bcb0991SDimitry Andric return new MVETailPredication(); 4578bcb0991SDimitry Andric } 4588bcb0991SDimitry Andric 4598bcb0991SDimitry Andric char MVETailPredication::ID = 0; 4608bcb0991SDimitry Andric 4618bcb0991SDimitry Andric INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false) 4628bcb0991SDimitry Andric INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false) 463