//===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv8.1m introduced MVE, the M-Profile Vector Extension, and low-overhead
/// branches to help accelerate DSP applications. These two extensions,
/// combined with a new form of predication called tail-predication, can be
/// used to provide implicit vector predication within a low-overhead loop.
/// This is implicit because the predicate of active/inactive lanes is
/// calculated by hardware, and thus does not need to be explicitly passed
/// to vector instructions. The instructions responsible for this are the
/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and
/// the total number of data elements processed by the loop. The loop-end
/// LETP instruction is responsible for decrementing the number of remaining
/// elements to be processed and generating the mask of active lanes.
///
/// The HardwareLoops pass inserts intrinsics identifying loops that the
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
/// responsible for generating a vectorized loop in which the lanes are
/// predicated upon the iteration counter. This pass looks at these predicated
/// vector loops, which are targets for low-overhead loops, and prepares them
/// for code generation. Once the vectorizer has produced a masked loop, there
/// are a couple of final forms:
///  - A tail-predicated loop, with implicit predication.
///  - A loop containing multiple VCTP instructions, predicating multiple VPT
///    blocks of instructions operating on different vector types.
///
/// This pass:
/// 1) Checks if the predicates of the masked load/store instructions are
///    generated by intrinsic @llvm.get.active.lane.mask(). This intrinsic
///    consumes the Backedge Taken Count (BTC) of the scalar loop as its second
///    argument, which we extract to set up the number of elements processed
///    by the loop.
/// 2) Intrinsic @llvm.get.active.lane.mask() is then replaced by the MVE
///    target specific VCTP intrinsic to represent the effect of tail
///    predication. This will be picked up by the ARM Low-overhead loop pass,
///    which performs the final transformation to a DLSTP or WLSTP
///    tail-predicated loop.
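///
/// As a rough, simplified sketch of the rewrite performed in step 2 (the value
/// names and the element type here are illustrative only), a masked load in
/// the vectorized loop such as:
///
///   %active = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %btc)
///   %wide.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(
///                    <4 x i32>* %addr, i32 4, <4 x i1> %active, <4 x i32> undef)
///
/// ends up predicated on the element-counting VCTP form instead:
///
///   %pred = call <4 x i1> @llvm.arm.mve.vctp32(i32 %elts.rem)
///   %wide.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(
///                    <4 x i32>* %addr, i32 4, <4 x i1> %pred, <4 x i32> undef)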

#include "ARM.h"
#include "ARMSubtarget.h"
#include "ARMTargetTransformInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"

using namespace llvm;

#define DEBUG_TYPE "mve-tail-predication"
#define DESC "Transform predicated vector loops to use MVE tail predication"

cl::opt<TailPredication::Mode> EnableTailPredication(
    "tail-predication", cl::desc("MVE tail-predication options"),
    cl::init(TailPredication::Disabled),
    cl::values(clEnumValN(TailPredication::Disabled, "disabled",
                          "Don't tail-predicate loops"),
               clEnumValN(TailPredication::EnabledNoReductions,
                          "enabled-no-reductions",
                          "Enable tail-predication, but not for reduction loops"),
               clEnumValN(TailPredication::Enabled,
                          "enabled",
                          "Enable tail-predication, including reduction loops"),
               clEnumValN(TailPredication::ForceEnabledNoReductions,
                          "force-enabled-no-reductions",
                          "Enable tail-predication, but not for reduction loops, "
                          "and force this which might be unsafe"),
               clEnumValN(TailPredication::ForceEnabled,
                          "force-enabled",
                          "Enable tail-predication, including reduction loops, "
                          "and force this which might be unsafe")));
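
// Note: EnableTailPredication is a regular command-line option, so (as an
// illustrative example only, exact tool invocations may differ per build)
// tail-predication can be switched on with something like:
//
//   llc -mtriple=thumbv8.1m.main -mattr=+mve -tail-predication=enabled ...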

namespace {

class MVETailPredication : public LoopPass {
  SmallVector<IntrinsicInst*, 4> MaskedInsts;
  Loop *L = nullptr;
  ScalarEvolution *SE = nullptr;
  TargetTransformInfo *TTI = nullptr;
  const ARMSubtarget *ST = nullptr;

public:
  static char ID;

  MVETailPredication() : LoopPass(ID) { }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.setPreservesCFG();
  }

  bool runOnLoop(Loop *L, LPPassManager&) override;

private:
  /// Perform the relevant checks on the loop and convert if possible.
  bool TryConvert(Value *TripCount);

  /// Return whether this is a vectorized loop that contains masked
  /// load/stores.
  bool IsPredicatedVectorLoop();

  /// Perform checks on the arguments of the @llvm.get.active.lane.mask
  /// intrinsic: check that the first is a loop induction variable, and for
  /// the second check that no overflow can occur in the expression that uses
  /// this backedge-taken count.
  bool IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount,
                        FixedVectorType *VecTy);

  /// Insert the intrinsic to represent the effect of tail predication.
  void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *TripCount,
                           FixedVectorType *VecTy);

  /// Rematerialize the iteration count in exit blocks, which enables
  /// ARMLowOverheadLoops to better optimise away loop update statements inside
  /// hardware-loops.
  void RematerializeIterCount();
};

} // end namespace

static bool IsDecrement(Instruction &I) {
  auto *Call = dyn_cast<IntrinsicInst>(&I);
  if (!Call)
    return false;

  Intrinsic::ID ID = Call->getIntrinsicID();
  return ID == Intrinsic::loop_decrement_reg;
}

static bool IsMasked(Instruction *I) {
  auto *Call = dyn_cast<IntrinsicInst>(I);
  if (!Call)
    return false;

  Intrinsic::ID ID = Call->getIntrinsicID();
  // TODO: Support gather/scatter expand/compress operations.
  return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load;
}
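
// For reference (value names are illustrative), the masked intrinsics accepted
// by IsMasked have the following shape:
//
//   %v = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(
//            <4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
//   call void @llvm.masked.store.v4i32.p0v4i32(
//            <4 x i32> %val, <4 x i32>* %ptr, i32 4, <4 x i1> %mask)
//
// which is why the pointer is operand 0 for loads and operand 1 for stores
// (see getVectorType), and the predicate is operand 2 for loads and operand 3
// for stores (see TryConvert).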

bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
  if (skipLoop(L) || !EnableTailPredication)
    return false;

  MaskedInsts.clear();
  Function &F = *L->getHeader()->getParent();
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  ST = &TM.getSubtarget<ARMSubtarget>(F);
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  this->L = L;

  // The MVE and LOB extensions are combined to enable tail-predication, but
  // there's nothing preventing us from generating VCTP instructions for v8.1m.
  if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
    LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
    return false;
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  if (!Preheader)
    return false;

  auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
    for (auto &I : *BB) {
      auto *Call = dyn_cast<IntrinsicInst>(&I);
      if (!Call)
        continue;

      Intrinsic::ID ID = Call->getIntrinsicID();
      if (ID == Intrinsic::set_loop_iterations ||
          ID == Intrinsic::test_set_loop_iterations)
        return cast<IntrinsicInst>(&I);
    }
    return nullptr;
  };

  // Look for the hardware loop intrinsic that sets the iteration count.
  IntrinsicInst *Setup = FindLoopIterations(Preheader);

  // The test.set iteration could live in the pre-preheader.
  if (!Setup) {
    if (!Preheader->getSinglePredecessor())
      return false;
    Setup = FindLoopIterations(Preheader->getSinglePredecessor());
    if (!Setup)
      return false;
  }
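
  // As a simplified, illustrative sketch (value names are made up), the
  // hardware-loop intrinsics searched for above and below typically look like:
  //
  //   preheader:
  //     call void @llvm.set.loop.iterations.i32(i32 %trip.count)
  //     ...
  //   vector.body:
  //     ...
  //     %rem = call i32 @llvm.loop.decrement.reg.i32.i32.i32(i32 %count, i32 1)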

  // Search for the hardware loop intrinsic that decrements the loop counter.
  IntrinsicInst *Decrement = nullptr;
  for (auto *BB : L->getBlocks()) {
    for (auto &I : *BB) {
      if (IsDecrement(I)) {
        Decrement = cast<IntrinsicInst>(&I);
        break;
      }
    }
  }

  if (!Decrement)
    return false;

  LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
             << *Decrement << "\n");

  if (!TryConvert(Setup->getArgOperand(0))) {
    LLVM_DEBUG(dbgs() << "ARM TP: Can't tail-predicate this loop.\n");
    return false;
  }

  return true;
}

static FixedVectorType *getVectorType(IntrinsicInst *I) {
  unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1;
  auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType());
  auto *VecTy = cast<FixedVectorType>(PtrTy->getElementType());
  assert(VecTy && "No scalable vectors expected here");
  return VecTy;
}

bool MVETailPredication::IsPredicatedVectorLoop() {
  // Check that the loop contains at least one masked load/store intrinsic.
  // Other than masked load/stores, we only support 'normal' vector
  // instructions.
  bool ActiveLaneMask = false;
  for (auto *BB : L->getBlocks()) {
    for (auto &I : *BB) {
      auto *Int = dyn_cast<IntrinsicInst>(&I);
      if (!Int)
        continue;

      switch (Int->getIntrinsicID()) {
      case Intrinsic::get_active_lane_mask:
        ActiveLaneMask = true;
        LLVM_FALLTHROUGH;
      case Intrinsic::sadd_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::usub_sat:
        continue;
      case Intrinsic::fma:
      case Intrinsic::trunc:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::floor:
      case Intrinsic::ceil:
      case Intrinsic::fabs:
        if (ST->hasMVEFloatOps())
          continue;
        LLVM_FALLTHROUGH;
      default:
        break;
      }

      if (IsMasked(&I)) {
        auto *VecTy = getVectorType(Int);
        unsigned Lanes = VecTy->getNumElements();
        unsigned ElementWidth = VecTy->getScalarSizeInBits();
        // MVE vectors are 128-bit, but don't support 128 x i1.
        // TODO: Can we support vectors larger than 128-bits?
        unsigned MaxWidth = TTI->getRegisterBitWidth(true);
        if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth)
          return false;
        MaskedInsts.push_back(cast<IntrinsicInst>(&I));
        continue;
      }

      for (const Use &U : Int->args()) {
        if (isa<VectorType>(U->getType()))
          return false;
      }
    }
  }

  if (!ActiveLaneMask) {
    LLVM_DEBUG(dbgs() << "ARM TP: No get.active.lane.mask intrinsic found.\n");
    return false;
  }
  return !MaskedInsts.empty();
}
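
// With 128-bit MVE vectors, the masked vector types that pass the width check
// above include, for example, <16 x i8>, <8 x i16> and <4 x i32>; these
// correspond to the vctp8/vctp16/vctp32 intrinsics chosen in
// InsertVCTPIntrinsic below.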

// Look through the exit block to see whether there's a duplicate predicate
// instruction. This can happen when we need to perform a select on values
// from the last and previous iteration. Instead of doing a straight
// replacement of that predicate with the vctp, clone the vctp and place it
// in the block. This means that the VPR doesn't have to be live into the
// exit block which should make it easier to convert this loop into a proper
// tail predicated loop.
static void Cleanup(SetVector<Instruction*> &MaybeDead, Loop *L) {
  BasicBlock *Exit = L->getUniqueExitBlock();
  if (!Exit) {
    LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n");
    return;
  }

  // Drop references and add operands to check for dead.
  SmallPtrSet<Instruction*, 4> Dead;
  while (!MaybeDead.empty()) {
    auto *I = MaybeDead.front();
    MaybeDead.remove(I);
    if (I->hasNUsesOrMore(1))
      continue;

    for (auto &U : I->operands())
      if (auto *OpI = dyn_cast<Instruction>(U))
        MaybeDead.insert(OpI);

    Dead.insert(I);
  }

  for (auto *I : Dead) {
    LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump());
    I->eraseFromParent();
  }

  for (auto I : L->blocks())
    DeleteDeadPHIs(I);
}

// The active lane intrinsic has this form:
//
//    @llvm.get.active.lane.mask(IV, BTC)
//
// Here we perform checks that this intrinsic behaves as expected,
// which means:
//
// 1) The element count, which is calculated with BTC + 1, cannot overflow.
// 2) The element count needs to be sufficiently large that the decrement of
//    the element counter doesn't overflow, which means that we need to prove:
//        ceil(ElementCount / VectorWidth) >= TripCount
//    by rounding ElementCount up:
//        ((ElementCount + (VectorWidth - 1)) / VectorWidth)
//    and evaluating whether this expression is known to be non-negative:
//        ((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
// 3) The IV must be an induction phi with an increment equal to the
//    vector width.
bool MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
    Value *TripCount, FixedVectorType *VecTy) {
  bool ForceTailPredication =
    EnableTailPredication == TailPredication::ForceEnabledNoReductions ||
    EnableTailPredication == TailPredication::ForceEnabled;
  // 1) Test whether entry to the loop is protected by a conditional
  // BTC + 1 < 0. In other words, if the scalar trip count overflows and
  // becomes negative, we shouldn't enter the loop and creating the
  // tripcount expression BTC + 1 is not safe.
  // So, check that BTC isn't max. This is evaluated in unsigned, because the
  // semantics of @llvm.get.active.lane.mask is a ULE comparison.

  int VectorWidth = VecTy->getNumElements();
  auto *BackedgeTakenCount = ActiveLaneMask->getOperand(1);
  auto *BTC = SE->getSCEV(BackedgeTakenCount);

  if (!llvm::cannotBeMaxInLoop(BTC, L, *SE, false /*Signed*/) &&
      !ForceTailPredication) {
    LLVM_DEBUG(dbgs() << "ARM TP: Overflow possible, BTC can be max: ";
               BTC->dump());
    return false;
  }

  // 2) Prove that the sub expression is non-negative, i.e. it doesn't
  // overflow:
  //
  //     ((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
  //
  // 2.1) First prove overflow can't happen in:
  //
  //     ElementCount + (VectorWidth - 1)
  //
  // Because of a lack of context, it is difficult to get useful bounds on
  // this expression. But since ElementCount uses the same variables as the
  // TripCount (TC), for which we can find meaningful value ranges, we use that
  // instead and assert that:
  //
  //     upperbound(TC) <= UINT_MAX - VectorWidth
  //
  auto *TC = SE->getSCEV(TripCount);
  unsigned SizeInBits = TripCount->getType()->getScalarSizeInBits();
  auto Diff = APInt(SizeInBits, ~0) - APInt(SizeInBits, VectorWidth);
  uint64_t MaxMinusVW = Diff.getZExtValue();
  uint64_t UpperboundTC = SE->getSignedRange(TC).getUpper().getZExtValue();

  if (UpperboundTC > MaxMinusVW && !ForceTailPredication) {
    LLVM_DEBUG(dbgs() << "ARM TP: Overflow possible in tripcount rounding:\n";
               dbgs() << "upperbound(TC) <= UINT_MAX - VectorWidth\n";
               dbgs() << UpperboundTC << " <= " << MaxMinusVW << " == false\n";);
    return false;
  }
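
  // For example (a worked, illustrative case): with i32 counters and a vector
  // width of 4, MaxMinusVW is 0xFFFFFFFF - 4 = 0xFFFFFFFB, so a trip count
  // whose upper bound can exceed 0xFFFFFFFB is rejected here, because the
  // rounding-up in ElementCount + (VectorWidth - 1) could then wrap.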

  // 2.2) Make sure overflow doesn't happen in the final expression:
  //
  //     ((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
  //
  // To do this, compare the full ranges of these subexpressions:
  //
  //     Range(Ceil) <= Range(TC)
  //
  // where Ceil = (ElementCount + (VW-1)) / VW. If Ceil and TC are runtime
  // values (and not constants), we have to compensate for the lower bound of
  // the value range being off by 1. The reason is that BTC lives in the
  // preheader in this form:
  //
  //     %trip.count.minus = add nsw nuw i32 %N, -1
  //
  // For the loop to be executed, %N has to be >= 1 and as a result the value
  // range of %trip.count.minus has a lower bound of 0. Value %TC has this
  // form:
  //
  //     %5 = add nuw nsw i32 %4, 1
  //     call void @llvm.set.loop.iterations.i32(i32 %5)
  //
  // where %5 is some expression using %N, which needs to have a lower bound
  // of 1. Thus, if the ranges of Ceil and TC are not a single constant but a
  // set, we first add 0 to TC such that we can do the <= comparison on both
  // sets.
  //
  auto *One = SE->getOne(TripCount->getType());
  // ElementCount = BTC + 1
  auto *ElementCount = SE->getAddExpr(BTC, One);
  // ECPlusVWMinus1 = ElementCount + (VW-1)
  auto *ECPlusVWMinus1 = SE->getAddExpr(ElementCount,
      SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1)));
  // Ceil = (ElementCount + (VW-1)) / VW
  auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1,
      SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth)));

  ConstantRange RangeCeil = SE->getSignedRange(Ceil);
  ConstantRange RangeTC = SE->getSignedRange(TC);
  if (!RangeTC.isSingleElement()) {
    auto ZeroRange =
        ConstantRange(APInt(TripCount->getType()->getScalarSizeInBits(), 0));
    RangeTC = RangeTC.unionWith(ZeroRange);
  }
  if (!RangeTC.contains(RangeCeil) && !ForceTailPredication) {
    LLVM_DEBUG(dbgs() << "ARM TP: Overflow possible in sub\n");
    return false;
  }

  // 3) Find out if IV is an induction phi. Note that we can't use Loop
  // helpers here to get the induction variable, because the hardware loop is
  // no longer in loopsimplify form, and also the hwloop intrinsic uses a
  // different counter. Using SCEV, we check that the induction is of the
  // form i = i + 4, where the increment must be equal to the VectorWidth.
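  //
  // For a typical 4-lane vector loop, the add-recurrence expected here prints
  // roughly as (illustrative only; the wrap flags may differ):
  //
  //     {0,+,4}<nuw><nsw><%vector.body>
  //
  // i.e. the IV starts at 0 and steps by the vector width each iteration.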
  auto *IV = ActiveLaneMask->getOperand(0);
  auto *IVExpr = SE->getSCEV(IV);
  auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
  if (!AddExpr) {
    LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump());
    return false;
  }
  // Check that this AddRec is associated with this loop.
  if (AddExpr->getLoop() != L) {
    LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n");
    return false;
  }
  auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
  if (!Step) {
    LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: ";
               AddExpr->getOperand(1)->dump());
    return false;
  }
  auto StepValue = Step->getValue()->getSExtValue();
  if (VectorWidth == StepValue)
    return true;

  LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue << " doesn't match "
                       "vector width " << VectorWidth << "\n");

  return false;
}

// Materialize NumElements in the preheader block.
static Value *getNumElements(BasicBlock *Preheader, Value *BTC) {
  // First, check in the preheader whether BTC already exists in this form:
  //
  // preheader:
  //    %BTC = add i32 %N, -1
  //    ..
  // vector.body:
  //
  // If %BTC already exists, we don't need to emit %NumElems = %BTC + 1, but
  // can instead just return %N.
  for (auto &I : *Preheader) {
    if (I.getOpcode() != Instruction::Add || &I != BTC)
      continue;
    ConstantInt *MinusOne = nullptr;
    if (!(MinusOne = dyn_cast<ConstantInt>(I.getOperand(1))))
      continue;
    if (MinusOne->getSExtValue() == -1) {
      LLVM_DEBUG(dbgs() << "ARM TP: Found num elems: " << I << "\n");
      return I.getOperand(0);
    }
  }

  // But we do need to materialise the element count (BTC + 1) if it is not
  // already there, e.g. if BTC is a constant.
  IRBuilder<> Builder(Preheader->getTerminator());
  Value *NumElements = Builder.CreateAdd(BTC,
      ConstantInt::get(BTC->getType(), 1), "num.elements");
  LLVM_DEBUG(dbgs() << "ARM TP: Created num elems: " << *NumElements << "\n");
  return NumElements;
}

void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
    Value *TripCount, FixedVectorType *VecTy) {
  IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
  Module *M = L->getHeader()->getModule();
  Type *Ty = IntegerType::get(M->getContext(), 32);
  unsigned VectorWidth = VecTy->getNumElements();

  // The backedge-taken count in @llvm.get.active.lane.mask, its 2nd operand,
  // is one less than the trip count. So we need to find or create
  // %num.elements = %BTC + 1 in the preheader.
  Value *BTC = ActiveLaneMask->getOperand(1);
  Builder.SetInsertPoint(L->getLoopPreheader()->getTerminator());
  Value *NumElements = getNumElements(L->getLoopPreheader(), BTC);

  // Insert a phi to count the number of elements processed by the loop.
  Builder.SetInsertPoint(L->getHeader()->getFirstNonPHI());
  PHINode *Processed = Builder.CreatePHI(Ty, 2);
  Processed->addIncoming(NumElements, L->getLoopPreheader());

  // Replace @llvm.get.active.lane.mask() with the ARM-specific VCTP intrinsic,
  // and thus represent the effect of tail predication.
  Builder.SetInsertPoint(ActiveLaneMask);
  ConstantInt *Factor =
    ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);

  Intrinsic::ID VCTPID;
  switch (VectorWidth) {
  default:
    llvm_unreachable("unexpected number of lanes");
  case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
  case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
  case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;

    // FIXME: vctp64 currently not supported because the predicate
    // vector wants to be <2 x i1>, but v2i1 is not a legal MVE
    // type, so problems happen at isel time.
    // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics
    // purposes, but takes a v4i1 instead of a v2i1.
  }
  Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
  Value *VCTPCall = Builder.CreateCall(VCTP, Processed);
  ActiveLaneMask->replaceAllUsesWith(VCTPCall);

  // Add the incoming value to the new phi.
  // TODO: This add likely already exists in the loop.
  Value *Remaining = Builder.CreateSub(Processed, Factor);
  Processed->addIncoming(Remaining, L->getLoopLatch());
  LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
             << *Processed << "\n"
             << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n");
}
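
// After InsertVCTPIntrinsic, the loop header roughly looks like this
// (illustrative IR only; the names are made up and the exact placement of the
// sub depends on where the active lane mask was):
//
//   vector.body:
//     %elts.rem = phi i32 [ %num.elements, %preheader ], [ %elts.next, %vector.body ]
//     %pred = call <4 x i1> @llvm.arm.mve.vctp32(i32 %elts.rem)
//     ... masked loads/stores now take %pred as their predicate ...
//     %elts.next = sub i32 %elts.rem, 4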

bool MVETailPredication::TryConvert(Value *TripCount) {
  if (!IsPredicatedVectorLoop()) {
    LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop.\n");
    return false;
  }

  LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
  SetVector<Instruction*> Predicates;

  // Walk through the masked intrinsics and try to find whether the predicate
  // operand is generated by intrinsic @llvm.get.active.lane.mask().
  for (auto *I : MaskedInsts) {
    unsigned PredOp = I->getIntrinsicID() == Intrinsic::masked_load ? 2 : 3;
    auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp));
    if (!Predicate || Predicates.count(Predicate))
      continue;

    auto *ActiveLaneMask = dyn_cast<IntrinsicInst>(Predicate);
    if (!ActiveLaneMask ||
        ActiveLaneMask->getIntrinsicID() != Intrinsic::get_active_lane_mask)
      continue;

    Predicates.insert(Predicate);
    LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: "
                      << *ActiveLaneMask << "\n");

    auto *VecTy = getVectorType(I);
    if (!IsSafeActiveMask(ActiveLaneMask, TripCount, VecTy)) {
      LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n");
      return false;
    }
    LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP.\n");
    InsertVCTPIntrinsic(ActiveLaneMask, TripCount, VecTy);
  }

  Cleanup(Predicates, L);
  return true;
}

Pass *llvm::createMVETailPredicationPass() {
  return new MVETailPredication();
}

char MVETailPredication::ID = 0;

INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)