//===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
/// branches to help accelerate DSP applications. These two extensions can be
/// combined to provide implicit vector predication within a low-overhead loop.
/// The HardwareLoops pass inserts intrinsics identifying loops that the
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
/// responsible for generating a vectorized loop in which the lanes are
/// predicated upon the iteration counter. This pass looks at these predicated
/// vector loops, that are targets for low-overhead loops, and prepares them for
/// code generation. Once the vectorizer has produced a masked loop, there are
/// a couple of final forms:
/// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCTP instructions, predicating multiple VPT
///   blocks of instructions operating on different vector types.
23 24 #include "llvm/Analysis/LoopInfo.h" 25 #include "llvm/Analysis/LoopPass.h" 26 #include "llvm/Analysis/ScalarEvolution.h" 27 #include "llvm/Analysis/ScalarEvolutionExpander.h" 28 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 29 #include "llvm/Analysis/TargetTransformInfo.h" 30 #include "llvm/CodeGen/TargetPassConfig.h" 31 #include "llvm/IR/Instructions.h" 32 #include "llvm/IR/IRBuilder.h" 33 #include "llvm/IR/PatternMatch.h" 34 #include "llvm/Support/Debug.h" 35 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 36 #include "ARM.h" 37 #include "ARMSubtarget.h" 38 39 using namespace llvm; 40 41 #define DEBUG_TYPE "mve-tail-predication" 42 #define DESC "Transform predicated vector loops to use MVE tail predication" 43 44 static cl::opt<bool> 45 DisableTailPredication("disable-mve-tail-predication", cl::Hidden, 46 cl::init(true), 47 cl::desc("Disable MVE Tail Predication")); 48 namespace { 49 50 class MVETailPredication : public LoopPass { 51 SmallVector<IntrinsicInst*, 4> MaskedInsts; 52 Loop *L = nullptr; 53 ScalarEvolution *SE = nullptr; 54 TargetTransformInfo *TTI = nullptr; 55 56 public: 57 static char ID; 58 59 MVETailPredication() : LoopPass(ID) { } 60 61 void getAnalysisUsage(AnalysisUsage &AU) const override { 62 AU.addRequired<ScalarEvolutionWrapperPass>(); 63 AU.addRequired<LoopInfoWrapperPass>(); 64 AU.addRequired<TargetPassConfig>(); 65 AU.addRequired<TargetTransformInfoWrapperPass>(); 66 AU.addPreserved<LoopInfoWrapperPass>(); 67 AU.setPreservesCFG(); 68 } 69 70 bool runOnLoop(Loop *L, LPPassManager&) override; 71 72 private: 73 74 /// Perform the relevant checks on the loop and convert if possible. 75 bool TryConvert(Value *TripCount); 76 77 /// Return whether this is a vectorized loop, that contains masked 78 /// load/stores. 79 bool IsPredicatedVectorLoop(); 80 81 /// Compute a value for the total number of elements that the predicated 82 /// loop will process. 
83 Value *ComputeElements(Value *TripCount, VectorType *VecTy); 84 85 /// Is the icmp that generates an i1 vector, based upon a loop counter 86 /// and a limit that is defined outside the loop. 87 bool isTailPredicate(Instruction *Predicate, Value *NumElements); 88 }; 89 90 } // end namespace 91 92 static bool IsDecrement(Instruction &I) { 93 auto *Call = dyn_cast<IntrinsicInst>(&I); 94 if (!Call) 95 return false; 96 97 Intrinsic::ID ID = Call->getIntrinsicID(); 98 return ID == Intrinsic::loop_decrement_reg; 99 } 100 101 static bool IsMasked(Instruction *I) { 102 auto *Call = dyn_cast<IntrinsicInst>(I); 103 if (!Call) 104 return false; 105 106 Intrinsic::ID ID = Call->getIntrinsicID(); 107 // TODO: Support gather/scatter expand/compress operations. 108 return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load; 109 } 110 111 bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { 112 if (skipLoop(L) || DisableTailPredication) 113 return false; 114 115 Function &F = *L->getHeader()->getParent(); 116 auto &TPC = getAnalysis<TargetPassConfig>(); 117 auto &TM = TPC.getTM<TargetMachine>(); 118 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 119 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 120 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 121 this->L = L; 122 123 // The MVE and LOB extensions are combined to enable tail-predication, but 124 // there's nothing preventing us from generating VCTP instructions for v8.1m. 
125 if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { 126 LLVM_DEBUG(dbgs() << "TP: Not a v8.1m.main+mve target.\n"); 127 return false; 128 } 129 130 BasicBlock *Preheader = L->getLoopPreheader(); 131 if (!Preheader) 132 return false; 133 134 auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { 135 for (auto &I : *BB) { 136 auto *Call = dyn_cast<IntrinsicInst>(&I); 137 if (!Call) 138 continue; 139 140 Intrinsic::ID ID = Call->getIntrinsicID(); 141 if (ID == Intrinsic::set_loop_iterations || 142 ID == Intrinsic::test_set_loop_iterations) 143 return cast<IntrinsicInst>(&I); 144 } 145 return nullptr; 146 }; 147 148 // Look for the hardware loop intrinsic that sets the iteration count. 149 IntrinsicInst *Setup = FindLoopIterations(Preheader); 150 151 // The test.set iteration could live in the pre- preheader. 152 if (!Setup) { 153 if (!Preheader->getSinglePredecessor()) 154 return false; 155 Setup = FindLoopIterations(Preheader->getSinglePredecessor()); 156 if (!Setup) 157 return false; 158 } 159 160 // Search for the hardware loop intrinic that decrements the loop counter. 161 IntrinsicInst *Decrement = nullptr; 162 for (auto *BB : L->getBlocks()) { 163 for (auto &I : *BB) { 164 if (IsDecrement(I)) { 165 Decrement = cast<IntrinsicInst>(&I); 166 break; 167 } 168 } 169 } 170 171 if (!Decrement) 172 return false; 173 174 LLVM_DEBUG(dbgs() << "TP: Running on Loop: " << *L 175 << *Setup << "\n" 176 << *Decrement << "\n"); 177 bool Changed = TryConvert(Setup->getArgOperand(0)); 178 return Changed; 179 } 180 181 bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) { 182 // Look for the following: 183 184 // %trip.count.minus.1 = add i32 %N, -1 185 // %broadcast.splatinsert10 = insertelement <4 x i32> undef, 186 // i32 %trip.count.minus.1, i32 0 187 // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10, 188 // <4 x i32> undef, 189 // <4 x i32> zeroinitializer 190 // ... 191 // ... 
192 // %index = phi i32 193 // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 194 // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, 195 // <4 x i32> undef, 196 // <4 x i32> zeroinitializer 197 // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3> 198 // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11 199 200 // And return whether V == %pred. 201 202 using namespace PatternMatch; 203 204 CmpInst::Predicate Pred; 205 Instruction *Shuffle = nullptr; 206 Instruction *Induction = nullptr; 207 208 // The vector icmp 209 if (!match(I, m_ICmp(Pred, m_Instruction(Induction), 210 m_Instruction(Shuffle))) || 211 Pred != ICmpInst::ICMP_ULE || !L->isLoopInvariant(Shuffle)) 212 return false; 213 214 // First find the stuff outside the loop which is setting up the limit 215 // vector.... 216 // The invariant shuffle that broadcast the limit into a vector. 217 Instruction *Insert = nullptr; 218 if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(), 219 m_Zero()))) 220 return false; 221 222 // Insert the limit into a vector. 223 Instruction *BECount = nullptr; 224 if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount), 225 m_Zero()))) 226 return false; 227 228 // The limit calculation, backedge count. 229 Value *TripCount = nullptr; 230 if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes()))) 231 return false; 232 233 if (TripCount != NumElements) 234 return false; 235 236 // Now back to searching inside the loop body... 237 // Find the add with takes the index iv and adds a constant vector to it. 238 Instruction *BroadcastSplat = nullptr; 239 Constant *Const = nullptr; 240 if (!match(Induction, m_Add(m_Instruction(BroadcastSplat), 241 m_Constant(Const)))) 242 return false; 243 244 // Check that we're adding <0, 1, 2, 3... 
245 if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) { 246 for (unsigned i = 0; i < CDS->getNumElements(); ++i) { 247 if (CDS->getElementAsInteger(i) != i) 248 return false; 249 } 250 } else 251 return false; 252 253 // The shuffle which broadcasts the index iv into a vector. 254 if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(), 255 m_Zero()))) 256 return false; 257 258 // The insert element which initialises a vector with the index iv. 259 Instruction *IV = nullptr; 260 if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero()))) 261 return false; 262 263 // The index iv. 264 auto *Phi = dyn_cast<PHINode>(IV); 265 if (!Phi) 266 return false; 267 268 // TODO: Don't think we need to check the entry value. 269 Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader()); 270 if (!match(OnEntry, m_Zero())) 271 return false; 272 273 Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch()); 274 unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements(); 275 276 Instruction *LHS = nullptr; 277 if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes)))) 278 return false; 279 280 return LHS == Phi; 281 } 282 283 static VectorType* getVectorType(IntrinsicInst *I) { 284 unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1; 285 auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType()); 286 return cast<VectorType>(PtrTy->getElementType()); 287 } 288 289 bool MVETailPredication::IsPredicatedVectorLoop() { 290 // Check that the loop contains at least one masked load/store intrinsic. 291 // We only support 'normal' vector instructions - other than masked 292 // load/stores. 
293 for (auto *BB : L->getBlocks()) { 294 for (auto &I : *BB) { 295 if (IsMasked(&I)) { 296 VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I)); 297 unsigned Lanes = VecTy->getNumElements(); 298 unsigned ElementWidth = VecTy->getScalarSizeInBits(); 299 // MVE vectors are 128-bit, but don't support 128 x i1. 300 // TODO: Can we support vectors larger than 128-bits? 301 unsigned MaxWidth = TTI->getRegisterBitWidth(true); 302 if (Lanes * ElementWidth != MaxWidth || Lanes == MaxWidth) 303 return false; 304 MaskedInsts.push_back(cast<IntrinsicInst>(&I)); 305 } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) { 306 for (auto &U : Int->args()) { 307 if (isa<VectorType>(U->getType())) 308 return false; 309 } 310 } 311 } 312 } 313 314 return !MaskedInsts.empty(); 315 } 316 317 Value* MVETailPredication::ComputeElements(Value *TripCount, 318 VectorType *VecTy) { 319 const SCEV *TripCountSE = SE->getSCEV(TripCount); 320 ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()), 321 VecTy->getNumElements()); 322 323 if (VF->equalsInt(1)) 324 return nullptr; 325 326 // TODO: Support constant trip counts. 
327 auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* { 328 if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 329 if (Const->getAPInt() != -VF->getValue()) 330 return nullptr; 331 } else 332 return nullptr; 333 return dyn_cast<SCEVMulExpr>(S->getOperand(1)); 334 }; 335 336 auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* { 337 if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 338 if (Const->getValue() != VF) 339 return nullptr; 340 } else 341 return nullptr; 342 return dyn_cast<SCEVUDivExpr>(S->getOperand(1)); 343 }; 344 345 auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* { 346 if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) { 347 if (Const->getValue() != VF) 348 return nullptr; 349 } else 350 return nullptr; 351 352 if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) { 353 if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) { 354 if (Const->getAPInt() != (VF->getValue() - 1)) 355 return nullptr; 356 } else 357 return nullptr; 358 359 return RoundUp->getOperand(1); 360 } 361 return nullptr; 362 }; 363 364 // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to 365 // determine the numbers of elements instead? Looks like this is what is used 366 // for delinearization, but I'm not sure if it can be applied to the 367 // vectorized form - at least not without a bit more work than I feel 368 // comfortable with. 
369 370 // Search for Elems in the following SCEV: 371 // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw> 372 const SCEV *Elems = nullptr; 373 if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE)) 374 if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1))) 375 if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS())) 376 if (auto *Mul = VisitAdd(Add)) 377 if (auto *Div = VisitMul(Mul)) 378 if (auto *Res = VisitDiv(Div)) 379 Elems = Res; 380 381 if (!Elems) 382 return nullptr; 383 384 Instruction *InsertPt = L->getLoopPreheader()->getTerminator(); 385 if (!isSafeToExpandAt(Elems, InsertPt, *SE)) 386 return nullptr; 387 388 auto DL = L->getHeader()->getModule()->getDataLayout(); 389 SCEVExpander Expander(*SE, DL, "elements"); 390 return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt); 391 } 392 393 // Look through the exit block to see whether there's a duplicate predicate 394 // instruction. This can happen when we need to perform a select on values 395 // from the last and previous iteration. Instead of doing a straight 396 // replacement of that predicate with the vctp, clone the vctp and place it 397 // in the block. This means that the VPR doesn't have to be live into the 398 // exit block which should make it easier to convert this loop into a proper 399 // tail predicated loop. 400 static void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates, 401 SetVector<Instruction*> &MaybeDead, Loop *L) { 402 if (BasicBlock *Exit = L->getUniqueExitBlock()) { 403 for (auto &Pair : NewPredicates) { 404 Instruction *OldPred = Pair.first; 405 Instruction *NewPred = Pair.second; 406 407 for (auto &I : *Exit) { 408 if (I.isSameOperationAs(OldPred)) { 409 Instruction *PredClone = NewPred->clone(); 410 PredClone->insertBefore(&I); 411 I.replaceAllUsesWith(PredClone); 412 MaybeDead.insert(&I); 413 break; 414 } 415 } 416 } 417 } 418 419 // Drop references and add operands to check for dead. 
420 SmallPtrSet<Instruction*, 4> Dead; 421 while (!MaybeDead.empty()) { 422 auto *I = MaybeDead.front(); 423 MaybeDead.remove(I); 424 if (I->hasNUsesOrMore(1)) 425 continue; 426 427 for (auto &U : I->operands()) { 428 if (auto *OpI = dyn_cast<Instruction>(U)) 429 MaybeDead.insert(OpI); 430 } 431 I->dropAllReferences(); 432 Dead.insert(I); 433 } 434 435 for (auto *I : Dead) 436 I->eraseFromParent(); 437 438 for (auto I : L->blocks()) 439 DeleteDeadPHIs(I); 440 } 441 442 bool MVETailPredication::TryConvert(Value *TripCount) { 443 if (!IsPredicatedVectorLoop()) 444 return false; 445 446 LLVM_DEBUG(dbgs() << "TP: Found predicated vector loop.\n"); 447 448 // Walk through the masked intrinsics and try to find whether the predicate 449 // operand is generated from an induction variable. 450 Module *M = L->getHeader()->getModule(); 451 Type *Ty = IntegerType::get(M->getContext(), 32); 452 SetVector<Instruction*> Predicates; 453 DenseMap<Instruction*, Instruction*> NewPredicates; 454 455 for (auto *I : MaskedInsts) { 456 Intrinsic::ID ID = I->getIntrinsicID(); 457 unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3; 458 auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp)); 459 if (!Predicate || Predicates.count(Predicate)) 460 continue; 461 462 VectorType *VecTy = getVectorType(I); 463 Value *NumElements = ComputeElements(TripCount, VecTy); 464 if (!NumElements) 465 continue; 466 467 if (!isTailPredicate(Predicate, NumElements)) { 468 LLVM_DEBUG(dbgs() << "TP: Not tail predicate: " << *Predicate << "\n"); 469 continue; 470 } 471 472 LLVM_DEBUG(dbgs() << "TP: Found tail predicate: " << *Predicate << "\n"); 473 Predicates.insert(Predicate); 474 475 // Insert a phi to count the number of elements processed by the loop. 
476 IRBuilder<> Builder(L->getHeader()->getFirstNonPHI()); 477 PHINode *Processed = Builder.CreatePHI(Ty, 2); 478 Processed->addIncoming(NumElements, L->getLoopPreheader()); 479 480 // Insert the intrinsic to represent the effect of tail predication. 481 Builder.SetInsertPoint(cast<Instruction>(Predicate)); 482 ConstantInt *Factor = 483 ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements()); 484 Intrinsic::ID VCTPID; 485 switch (VecTy->getNumElements()) { 486 default: 487 llvm_unreachable("unexpected number of lanes"); 488 case 2: VCTPID = Intrinsic::arm_vctp64; break; 489 case 4: VCTPID = Intrinsic::arm_vctp32; break; 490 case 8: VCTPID = Intrinsic::arm_vctp16; break; 491 case 16: VCTPID = Intrinsic::arm_vctp8; break; 492 } 493 Function *VCTP = Intrinsic::getDeclaration(M, VCTPID); 494 Value *TailPredicate = Builder.CreateCall(VCTP, Processed); 495 Predicate->replaceAllUsesWith(TailPredicate); 496 NewPredicates[Predicate] = cast<Instruction>(TailPredicate); 497 498 // Add the incoming value to the new phi. 499 // TODO: This add likely already exists in the loop. 500 Value *Remaining = Builder.CreateSub(Processed, Factor); 501 Processed->addIncoming(Remaining, L->getLoopLatch()); 502 LLVM_DEBUG(dbgs() << "TP: Insert processed elements phi: " 503 << *Processed << "\n" 504 << "TP: Inserted VCTP: " << *TailPredicate << "\n"); 505 } 506 507 // Now clean up. 508 Cleanup(NewPredicates, Predicates, L); 509 return true; 510 } 511 512 Pass *llvm::createMVETailPredicationPass() { 513 return new MVETailPredication(); 514 } 515 516 char MVETailPredication::ID = 0; 517 518 INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false) 519 INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false) 520