//===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
/// branches to help accelerate DSP applications. These two extensions can be
/// combined to provide implicit vector predication within a low-overhead loop.
/// The HardwareLoops pass inserts intrinsics identifying loops that the
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
/// responsible for generating a vectorized loop in which the lanes are
/// predicated upon the iteration counter. This pass looks at these predicated
/// vector loops, which are targets for low-overhead loops, and prepares them
/// for code generation. Once the vectorizer has produced a masked loop, there
/// are a couple of final forms:
/// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCPT instructions, predicating multiple VPT
///   blocks of instructions operating on different vector types.
///
/// This pass inserts the VCTP intrinsic to represent the effect of tail
/// predication. This will be picked up by the ARM Low-overhead loop pass,
/// which performs the final transformation to a DLSTP or WLSTP tail-predicated
/// loop.
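///
/// For illustration only (a simplified sketch; block and value names are
/// invented and not taken from a specific test), a predicated vector loop
/// produced by the vectorizer computes its lane mask with an icmp:
///
///   vector.body:
///     ...
///     %pred = icmp ule <4 x i32> %induction, %broadcast.splat
///     %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr,
///                 i32 4, <4 x i1> %pred, <4 x i32> undef)
///
/// and after this pass the mask is instead produced by a VCTP intrinsic that
/// counts down the remaining elements:
///
///   vector.body:
///     %elts = phi i32 [ %N, %vector.ph ], [ %elts.rem, %vector.body ]
///     %pred = call <4 x i1> @llvm.arm.mve.vctp32(i32 %elts)
///     %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr,
///                 i32 4, <4 x i1> %pred, <4 x i32> undef)
///     %elts.rem = sub i32 %elts, 4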

#include "ARM.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;

#define DEBUG_TYPE "mve-tail-predication"
#define DESC "Transform predicated vector loops to use MVE tail predication"

cl::opt<bool>
DisableTailPredication("disable-mve-tail-predication", cl::Hidden,
                       cl::init(true),
                       cl::desc("Disable MVE Tail Predication"));
namespace {

// Bookkeeping for pattern matching the loop trip count and the number of
// elements processed by the loop.
struct TripCountPattern {
  // The predicate used by the masked loads/stores, i.e. an icmp instruction
  // which calculates active/inactive lanes.
  Instruction *Predicate = nullptr;

  // The add instruction that increments the IV.
  Value *TripCount = nullptr;

  // The number of elements processed by the vector loop.
  Value *NumElements = nullptr;

  VectorType *VecTy = nullptr;
  Instruction *Shuffle = nullptr;
  Instruction *Induction = nullptr;

  TripCountPattern(Instruction *P, Value *TC, VectorType *VT)
      : Predicate(P), TripCount(TC), VecTy(VT) {}
};

class MVETailPredication : public LoopPass {
  SmallVector<IntrinsicInst*, 4> MaskedInsts;
  Loop *L = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL;
  DominatorTree *DT = nullptr;
  ScalarEvolution *SE = nullptr;
  TargetTransformInfo *TTI = nullptr;
  TargetLibraryInfo *TLI = nullptr;
  bool ClonedVCTPInExitBlock = false;

public:
  static char ID;

  MVETailPredication() : LoopPass(ID) { }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.setPreservesCFG();
  }

  bool runOnLoop(Loop *L, LPPassManager&) override;

private:
  /// Perform the relevant checks on the loop and convert it if possible.
  bool TryConvert(Value *TripCount);

  /// Return whether this is a vectorized loop that contains masked
  /// load/stores.
  bool IsPredicatedVectorLoop();

  /// Compute a value for the total number of elements that the predicated
  /// loop will process if it is a runtime value.
  bool ComputeRuntimeElements(TripCountPattern &TCP);

  /// Return whether this is the icmp that generates an i1 vector, based upon
  /// a loop counter and a limit that is defined outside the loop.
  bool isTailPredicate(TripCountPattern &TCP);

  /// Insert the intrinsic to represent the effect of tail predication.
  void InsertVCTPIntrinsic(TripCountPattern &TCP,
                           DenseMap<Instruction *, Instruction *> &NewPredicates);

  /// Rematerialize the iteration count in exit blocks, which enables
  /// ARMLowOverheadLoops to better optimise away loop update statements inside
  /// hardware-loops.
  void RematerializeIterCount();
};

} // end namespace

static bool IsDecrement(Instruction &I) {
  auto *Call = dyn_cast<IntrinsicInst>(&I);
  if (!Call)
    return false;

  Intrinsic::ID ID = Call->getIntrinsicID();
  return ID == Intrinsic::loop_decrement_reg;
}

static bool IsMasked(Instruction *I) {
  auto *Call = dyn_cast<IntrinsicInst>(I);
  if (!Call)
    return false;

  Intrinsic::ID ID = Call->getIntrinsicID();
  // TODO: Support gather/scatter expand/compress operations.
  return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load;
}

void MVETailPredication::RematerializeIterCount() {
  SmallVector<WeakTrackingVH, 16> DeadInsts;
  SCEVExpander Rewriter(*SE, *DL, "mvetp");
  ReplaceExitVal ReplaceExitValue = AlwaysRepl;

  formLCSSARecursively(*L, *DT, LI, SE);
  rewriteLoopExitValues(L, LI, TLI, SE, TTI, Rewriter, DT, ReplaceExitValue,
                        DeadInsts);
}

bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
  if (skipLoop(L) || DisableTailPredication)
    return false;

  MaskedInsts.clear();
  Function &F = *L->getHeader()->getParent();
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
  TLI = TLIP ? &TLIP->getTLI(*L->getHeader()->getParent()) : nullptr;
  DL = &L->getHeader()->getModule()->getDataLayout();
  this->L = L;

  // The MVE and LOB extensions are combined to enable tail-predication, but
  // there's nothing preventing us from generating VCTP instructions for v8.1m.
  if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
    LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
    return false;
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  if (!Preheader)
    return false;

  auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
    for (auto &I : *BB) {
      auto *Call = dyn_cast<IntrinsicInst>(&I);
      if (!Call)
        continue;

      Intrinsic::ID ID = Call->getIntrinsicID();
      if (ID == Intrinsic::set_loop_iterations ||
          ID == Intrinsic::test_set_loop_iterations)
        return cast<IntrinsicInst>(&I);
    }
    return nullptr;
  };

  // Look for the hardware loop intrinsic that sets the iteration count.
  IntrinsicInst *Setup = FindLoopIterations(Preheader);

  // The test.set iteration count intrinsic could live in the pre-preheader.
  if (!Setup) {
    if (!Preheader->getSinglePredecessor())
      return false;
    Setup = FindLoopIterations(Preheader->getSinglePredecessor());
    if (!Setup)
      return false;
  }

  // Search for the hardware loop intrinsic that decrements the loop counter.
  IntrinsicInst *Decrement = nullptr;
  for (auto *BB : L->getBlocks()) {
    for (auto &I : *BB) {
      if (IsDecrement(I)) {
        Decrement = cast<IntrinsicInst>(&I);
        break;
      }
    }
  }

  if (!Decrement)
    return false;

  ClonedVCTPInExitBlock = false;
  LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
                    << *Decrement << "\n");

  if (TryConvert(Setup->getArgOperand(0))) {
    if (ClonedVCTPInExitBlock)
      RematerializeIterCount();
    return true;
  }

  return false;
}

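// For reference, a simplified sketch (value and block names are invented, and
// the intrinsic signatures are simplified) of the hardware-loop intrinsics
// that runOnLoop looks for: the iteration count setup in the preheader (or
// pre-preheader for the test.set form) and the decrement inside the loop:
//
//   vector.ph:
//     call void @llvm.set.loop.iterations.i32(i32 %iters)
//     br label %vector.body
//   vector.body:
//     %count = phi i32 [ %iters, %vector.ph ], [ %remaining, %vector.body ]
//     ...
//     %remaining = call i32 @llvm.loop.decrement.reg.i32(i32 %count, i32 1)
//     %cmp = icmp ne i32 %remaining, 0
//     br i1 %cmp, label %vector.body, label %exit
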
// Pattern match predicates/masks and determine if they use the loop induction
// variable to control the number of elements processed by the loop. If so,
// the loop is a candidate for tail-predication.
bool MVETailPredication::isTailPredicate(TripCountPattern &TCP) {
  using namespace PatternMatch;

  // Pattern match the loop body and find the add that takes the index iv
  // and adds a constant vector to it:
  //
  // vector.body:
  // ..
  // %index = phi i32
  // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0
  // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert,
  //                                  <4 x i32> undef,
  //                                  <4 x i32> zeroinitializer
  // %induction = [add|or] <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3>
  // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11
  //
  // Please note that the 'or' is equivalent to the 'add' here; this relies on
  // BroadcastSplat being the IV which we know is a phi with 0 start and Lanes
  // increment, which is all being checked below.
  Instruction *BroadcastSplat = nullptr;
  Constant *Const = nullptr;
  if (!match(TCP.Induction,
             m_Add(m_Instruction(BroadcastSplat), m_Constant(Const))) &&
      !match(TCP.Induction,
             m_Or(m_Instruction(BroadcastSplat), m_Constant(Const))))
    return false;

  // Check that we're adding <0, 1, 2, 3...
  if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) {
    for (unsigned i = 0; i < CDS->getNumElements(); ++i) {
      if (CDS->getElementAsInteger(i) != i)
        return false;
    }
  } else
    return false;

  Instruction *Insert = nullptr;
  // The shuffle which broadcasts the index iv into a vector.
  if (!match(BroadcastSplat,
             m_ShuffleVector(m_Instruction(Insert), m_Undef(), m_ZeroMask())))
    return false;

  // The insert element which initialises a vector with the index iv.
  Instruction *IV = nullptr;
  if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero())))
    return false;

  // The index iv.
  auto *Phi = dyn_cast<PHINode>(IV);
  if (!Phi)
    return false;

  // TODO: Don't think we need to check the entry value.
  Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader());
  if (!match(OnEntry, m_Zero()))
    return false;

  Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch());
  unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements();

  Instruction *LHS = nullptr;
  if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes))))
    return false;

  return LHS == Phi;
}

static VectorType *getVectorType(IntrinsicInst *I) {
  unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1;
  auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType());
  return cast<VectorType>(PtrTy->getElementType());
}

bool MVETailPredication::IsPredicatedVectorLoop() {
  // Check that the loop contains at least one masked load/store intrinsic.
  // We only support 'normal' vector instructions - the only masked operations
  // we handle are masked load/stores.
  for (auto *BB : L->getBlocks()) {
    for (auto &I : *BB) {
      if (IsMasked(&I)) {
        VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I));
        unsigned Lanes = VecTy->getNumElements();
        unsigned ElementWidth = VecTy->getScalarSizeInBits();
        // MVE vectors are 128-bit, but don't support 128 x i1.
        // TODO: Can we support vectors larger than 128-bits?
        unsigned MaxWidth = TTI->getRegisterBitWidth(true);
        if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth)
          return false;
        MaskedInsts.push_back(cast<IntrinsicInst>(&I));
      } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) {
        for (auto &U : Int->args()) {
          if (isa<VectorType>(U->getType()))
            return false;
        }
      }
    }
  }

  return !MaskedInsts.empty();
}

// Pattern match the predicate, which is an icmp with a constant vector of this
// form:
//
// icmp ult <4 x i32> %induction, <i32 32002, i32 32002, i32 32002, i32 32002>
//
// and return the constant, i.e. 32002 in this example. This is assumed to be
// the scalar loop iteration count: the number of elements processed by the
// vector loop. Further checks are performed in function isTailPredicate(), to
// verify that 'induction' behaves as an induction variable.
//
static bool ComputeConstElements(TripCountPattern &TCP) {
  if (!isa<ConstantInt>(TCP.TripCount))
    return false;

  ConstantInt *VF = ConstantInt::get(
      cast<IntegerType>(TCP.TripCount->getType()), TCP.VecTy->getNumElements());
  using namespace PatternMatch;
  CmpInst::Predicate CC;

  if (!match(TCP.Predicate, m_ICmp(CC, m_Instruction(TCP.Induction),
                                   m_AnyIntegralConstant())) ||
      CC != ICmpInst::ICMP_ULT)
    return false;

  LLVM_DEBUG(dbgs() << "ARM TP: icmp with constants: "; TCP.Predicate->dump(););
  Value *ConstVec = TCP.Predicate->getOperand(1);

  auto *CDS = dyn_cast<ConstantDataSequential>(ConstVec);
  if (!CDS || CDS->getNumElements() != VF->getSExtValue())
    return false;

  if ((TCP.NumElements = CDS->getSplatValue())) {
    assert(cast<ConstantInt>(TCP.NumElements)->getSExtValue() %
               VF->getSExtValue() != 0 &&
           "tail-predication: trip count should not be a multiple of the VF");
    LLVM_DEBUG(dbgs() << "ARM TP: Found const elem count: " << *TCP.NumElements
                      << "\n");
    return true;
  }
  return false;
}

// Pattern match the loop iteration count setup:
//
// %trip.count.minus.1 = add i32 %N, -1
// %broadcast.splatinsert10 = insertelement <4 x i32> undef,
//                                          i32 %trip.count.minus.1, i32 0
// %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
//                                    <4 x i32> undef,
//                                    <4 x i32> zeroinitializer
// ..
// vector.body:
// ..
//
static bool MatchElemCountLoopSetup(Loop *L, Instruction *Shuffle,
                                    Value *NumElements) {
  using namespace PatternMatch;
  Instruction *Insert = nullptr;

  if (!match(Shuffle,
             m_ShuffleVector(m_Instruction(Insert), m_Undef(), m_ZeroMask())))
    return false;

  // Insert the limit into a vector.
  Instruction *BECount = nullptr;
  if (!match(Insert,
             m_InsertElement(m_Undef(), m_Instruction(BECount), m_Zero())))
    return false;

  // The limit calculation, backedge count.
  Value *TripCount = nullptr;
  if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes())))
    return false;

  if (TripCount != NumElements || !L->isLoopInvariant(BECount))
    return false;

  return true;
}

bool MVETailPredication::ComputeRuntimeElements(TripCountPattern &TCP) {
  using namespace PatternMatch;
  const SCEV *TripCountSE = SE->getSCEV(TCP.TripCount);
  ConstantInt *VF = ConstantInt::get(
      cast<IntegerType>(TCP.TripCount->getType()), TCP.VecTy->getNumElements());

  if (VF->equalsInt(1))
    return false;

  CmpInst::Predicate Pred;
  if (!match(TCP.Predicate, m_ICmp(Pred, m_Instruction(TCP.Induction),
                                   m_Instruction(TCP.Shuffle))) ||
      Pred != ICmpInst::ICMP_ULE)
    return false;

  LLVM_DEBUG(dbgs() << "Computing number of elements for vector trip count: ";
             TCP.TripCount->dump());

  // Otherwise, continue and try to pattern match the vector iteration
  // count expression.
  auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr * {
    if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
      if (Const->getAPInt() != -VF->getValue())
        return nullptr;
    } else
      return nullptr;
    return dyn_cast<SCEVMulExpr>(S->getOperand(1));
  };

  auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr * {
    if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
      if (Const->getValue() != VF)
        return nullptr;
    } else
      return nullptr;
    return dyn_cast<SCEVUDivExpr>(S->getOperand(1));
  };

  auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV * {
    if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) {
      if (Const->getValue() != VF)
        return nullptr;
    } else
      return nullptr;

    if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) {
      if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) {
        if (Const->getAPInt() != (VF->getValue() - 1))
          return nullptr;
      } else
        return nullptr;

      return RoundUp->getOperand(1);
    }
    return nullptr;
  };

  // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to
  // determine the numbers of elements instead? Looks like this is what is used
  // for delinearization, but I'm not sure if it can be applied to the
  // vectorized form - at least not without a bit more work than I feel
  // comfortable with.

  // Search for Elems in the following SCEV:
  // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw>
  const SCEV *Elems = nullptr;
  if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE))
    if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1)))
      if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS()))
        if (auto *Mul = VisitAdd(Add))
          if (auto *Div = VisitMul(Mul))
            if (auto *Res = VisitDiv(Div))
              Elems = Res;

  if (!Elems)
    return false;

  Instruction *InsertPt = L->getLoopPreheader()->getTerminator();
  if (!isSafeToExpandAt(Elems, InsertPt, *SE))
    return false;

  auto DL = L->getHeader()->getModule()->getDataLayout();
  SCEVExpander Expander(*SE, DL, "elements");
  TCP.NumElements = Expander.expandCodeFor(Elems, Elems->getType(), InsertPt);

  if (!MatchElemCountLoopSetup(L, TCP.Shuffle, TCP.NumElements))
    return false;

  return true;
}

// Look through the exit block to see whether there's a duplicate predicate
// instruction. This can happen when we need to perform a select on values
// from the last and previous iteration. Instead of doing a straight
// replacement of that predicate with the vctp, clone the vctp and place it
// in the block. This means that the VPR doesn't have to be live into the
// exit block which should make it easier to convert this loop into a proper
// tail predicated loop.
static bool Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates,
                    SetVector<Instruction*> &MaybeDead, Loop *L) {
  BasicBlock *Exit = L->getUniqueExitBlock();
  if (!Exit) {
    LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n");
    return false;
  }

  bool ClonedVCTPInExitBlock = false;

  for (auto &Pair : NewPredicates) {
    Instruction *OldPred = Pair.first;
    Instruction *NewPred = Pair.second;

    for (auto &I : *Exit) {
      if (I.isSameOperationAs(OldPred)) {
        Instruction *PredClone = NewPred->clone();
        PredClone->insertBefore(&I);
        I.replaceAllUsesWith(PredClone);
        MaybeDead.insert(&I);
        ClonedVCTPInExitBlock = true;
        LLVM_DEBUG(dbgs() << "ARM TP: replacing: "; I.dump();
                   dbgs() << "ARM TP: with: "; PredClone->dump());
        break;
      }
    }
  }

  // Drop references and add operands to check for dead.
  SmallPtrSet<Instruction*, 4> Dead;
  while (!MaybeDead.empty()) {
    auto *I = MaybeDead.front();
    MaybeDead.remove(I);
    if (I->hasNUsesOrMore(1))
      continue;

    for (auto &U : I->operands()) {
      if (auto *OpI = dyn_cast<Instruction>(U))
        MaybeDead.insert(OpI);
    }
    I->dropAllReferences();
    Dead.insert(I);
  }

  for (auto *I : Dead) {
    LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump());
    I->eraseFromParent();
  }

  for (auto *I : L->blocks())
    DeleteDeadPHIs(I);

  return ClonedVCTPInExitBlock;
}

void MVETailPredication::InsertVCTPIntrinsic(TripCountPattern &TCP,
                         DenseMap<Instruction*, Instruction*> &NewPredicates) {
  IRBuilder<> Builder(L->getHeader()->getFirstNonPHI());
  Module *M = L->getHeader()->getModule();
  Type *Ty = IntegerType::get(M->getContext(), 32);

  // Insert a phi to count the number of elements processed by the loop.
  PHINode *Processed = Builder.CreatePHI(Ty, 2);
  Processed->addIncoming(TCP.NumElements, L->getLoopPreheader());

  // Insert the intrinsic to represent the effect of tail predication.
  Builder.SetInsertPoint(cast<Instruction>(TCP.Predicate));
  ConstantInt *Factor =
      ConstantInt::get(cast<IntegerType>(Ty), TCP.VecTy->getNumElements());

  Intrinsic::ID VCTPID;
  switch (TCP.VecTy->getNumElements()) {
  default:
    llvm_unreachable("unexpected number of lanes");
  case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
  case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
  case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;

    // FIXME: vctp64 currently not supported because the predicate
    // vector wants to be <2 x i1>, but v2i1 is not a legal MVE
    // type, so problems happen at isel time.
    // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics
    // purposes, but takes a v4i1 instead of a v2i1.
  }
  Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
  Value *TailPredicate = Builder.CreateCall(VCTP, Processed);
  TCP.Predicate->replaceAllUsesWith(TailPredicate);
  NewPredicates[TCP.Predicate] = cast<Instruction>(TailPredicate);

  // Add the incoming value to the new phi.
  // TODO: This add likely already exists in the loop.
  Value *Remaining = Builder.CreateSub(Processed, Factor);
  Processed->addIncoming(Remaining, L->getLoopLatch());
  LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
                    << *Processed << "\n"
                    << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n");
}

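// For a v4i32 loop, the IR emitted by InsertVCTPIntrinsic above has roughly
// this shape (an illustrative sketch only; value names are invented and %elts
// stands for whatever TCP.NumElements expanded to in the preheader):
//
//   vector.body:
//     %processed = phi i32 [ %elts, %vector.ph ], [ %remaining, %vector.body ]
//     ...
//     %tp = call <4 x i1> @llvm.arm.mve.vctp32(i32 %processed)
//     %remaining = sub i32 %processed, 4
//     ; all users of the original icmp predicate now use %tp
//     ...
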
bool MVETailPredication::TryConvert(Value *TripCount) {
  if (!IsPredicatedVectorLoop()) {
    LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop.\n");
    return false;
  }

  LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");

  // Walk through the masked intrinsics and try to find whether the predicate
  // operand is generated from an induction variable.
  SetVector<Instruction*> Predicates;
  DenseMap<Instruction*, Instruction*> NewPredicates;

  for (auto *I : MaskedInsts) {
    Intrinsic::ID ID = I->getIntrinsicID();
    unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3;
    auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp));
    if (!Predicate || Predicates.count(Predicate))
      continue;

    TripCountPattern TCP(Predicate, TripCount, getVectorType(I));

    if (!(ComputeConstElements(TCP) || ComputeRuntimeElements(TCP)))
      continue;

    if (!isTailPredicate(TCP)) {
      LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate
                        << "\n");
      continue;
    }

    LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate
                      << "\n");
    Predicates.insert(Predicate);
    InsertVCTPIntrinsic(TCP, NewPredicates);
  }

  if (NewPredicates.empty())
    return false;

  // Now clean up.
  ClonedVCTPInExitBlock = Cleanup(NewPredicates, Predicates, L);
  return true;
}

Pass *llvm::createMVETailPredicationPass() {
  return new MVETailPredication();
}

char MVETailPredication::ID = 0;

INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)