//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.gather and llvm.scatter instructions to
// RISC-V intrinsics.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISC-V gather/scatter lowering";
  }

private:
  bool tryCreateStridedLoadStore(IntrinsicInst *II);

  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                     IRBuilderBase &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilderBase &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISC-V gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}
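
// Match a vector constant of the form <Start, Start+Stride, Start+2*Stride,
// ...>, e.g. <i32 2, i32 5, i32 8, i32 11> gives Start=2 and Stride=3.
// Returns the (Start, Stride) pair as scalar constants, or (nullptr, nullptr)
// if the constant is not strided.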
// TODO: Should we consider the mask when looking for a stride?
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  if (!isa<FixedVectorType>(StartC->getType()))
    return std::make_pair(nullptr, nullptr);

  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilderBase &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added or
  // multiplied.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || (BO->getOpcode() != Instruction::Add &&
              BO->getOpcode() != Instruction::Or &&
              BO->getOpcode() != Instruction::Shl &&
              BO->getOpcode() != Instruction::Mul))
    return std::make_pair(nullptr, nullptr);

  if (BO->getOpcode() == Instruction::Or &&
      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 0;
  Value *Splat = getSplatValue(BO->getOperand(1));
  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
    Splat = getSplatValue(BO->getOperand(0));
    OtherIndex = 1;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) =
      matchStridedStart(BO->getOperand(OtherIndex), Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  // Add the splat value to the start or multiply the start and stride by the
  // splat.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case Instruction::Or:
    // TODO: We'd be better off creating a disjoint or here, but we don't yet
    // have an IRBuilder API for that.
    [[fallthrough]];
  case Instruction::Add:
    Start = Builder.CreateAdd(Start, Splat);
    break;
  case Instruction::Mul:
    Start = Builder.CreateMul(Start, Splat);
    Stride = Builder.CreateMul(Stride, Splat);
    break;
  case Instruction::Shl:
    Start = Builder.CreateShl(Start, Splat);
    Stride = Builder.CreateShl(Stride, Splat);
    break;
  }

  return std::make_pair(Start, Stride);
}

// Recursively walk up the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
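//
// For example, given
//   %vec.ind = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %entry ],
//                            [ %vec.ind.next, %loop ]
//   %vec.ind.next = add <4 x i64> %vec.ind, splat (i64 4)
// we build the scalar recurrence
//   %vec.ind.scalar = phi i64 [ 0, %entry ], [ %vec.ind.next.scalar, %loop ]
//   %vec.ind.next.scalar = add i64 %vec.ind.scalar, 4
// and report a Stride of 1.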
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilderBase &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar",
                              Phi->getIterator());
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc->getIterator());
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for a binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  switch (BO->getOpcode()) {
  default:
    return false;
  case Instruction::Or:
    // We need to be able to treat Or as Add.
    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
      return false;
    break;
  case Instruction::Add:
    break;
  case Instruction::Shl:
    break;
  case Instruction::Mul:
    break;
  }

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1))) &&
             Instruction::isCommutative(BO->getOpcode())) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure the other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);
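
  // Now fold the splat operand into the scalar recurrence we just built. The
  // splat was applied to the whole index vector on every iteration, so it
  // must be folded into Start, and for Mul/Shl also into Stride and Step.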
  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  // TODO: Share this switch with matchStridedStart?
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
    Start = Builder.CreateAdd(Start, SplatOp, "start");
    break;
  }
  case Instruction::Mul: {
    Start = Builder.CreateMul(Start, SplatOp, "start");
    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    break;
  }
  case Instruction::Shl: {
    Start = Builder.CreateShl(Start, SplatOp, "start");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    break;
  }
  }

  // If the Step was defined inside the loop, adjust it before its definition
  // instead of in the preheader.
  if (auto *StepI = dyn_cast<Instruction>(Step); StepI && L->contains(StepI))
    Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());

  switch (BO->getOpcode()) {
  default:
    break;
  case Instruction::Mul:
    Step = Builder.CreateMul(Step, SplatOp, "step");
    break;
  case Instruction::Shl:
    Step = Builder.CreateShl(Step, SplatOp, "step");
    break;
  }

  Inc->setOperand(StepIndex, Step);
  BasePtr->setIncomingValue(StartBlock, Start);
  return true;
}
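
// Decompose Ptr into a scalar base pointer plus a scalar stride in bytes,
// inserting any scalar computation needed via Builder. Returns
// (nullptr, nullptr) if Ptr does not have a recognizable strided form.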
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                   IRBuilderBase &Builder) {

  // A gather/scatter of a splat is a zero strided load/store.
  if (auto *BasePtr = getSplatValue(Ptr)) {
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
  }

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return std::make_pair(nullptr, nullptr);

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // If the base pointer is a vector, check if it's strided.
  Value *Base = GEP->getPointerOperand();
  if (auto *BaseInst = dyn_cast<Instruction>(Base);
      BaseInst && BaseInst->getType()->isVectorTy()) {
    // If the GEP's offset is scalar then we can add it to the base pointer's
    // base.
    auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
    if (all_of(GEP->indices(), IsScalar)) {
      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
      if (BaseBase) {
        Builder.SetInsertPoint(GEP);
        SmallVector<Value *> Indices(GEP->indices());
        Value *OffsetBase =
            Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
                              GEP->getName() + "offset", GEP->isInBounds());
        return {OffsetBase, Stride};
      }
    }
  }

  // Base pointer needs to be a scalar.
  Value *ScalarBase = Base;
  if (ScalarBase->getType()->isVectorTy()) {
    ScalarBase = getSplatValue(ScalarBase);
    if (!ScalarBase)
      return std::make_pair(nullptr, nullptr);
  }

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = GTI.getSequentialElementStride(*DL);
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice. Handle one special case here - constants. This simplifies
  // writing test cases.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy) {
    auto *VecIndexC = dyn_cast<Constant>(VecIndex);
    if (!VecIndexC)
      return std::make_pair(nullptr, nullptr);
    if (VecIndex->getType()->getScalarSizeInBits() >
        VecIntPtrTy->getScalarSizeInBits())
      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
                                             VecIntPtrTy);
    else
      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
                                             VecIntPtrTy);
  }

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop that has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}
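
// Rewrite a gather/scatter intrinsic as an experimental.vp.strided.load or
// experimental.vp.strided.store when its pointer operand decomposes into a
// scalar base plus a scalar stride. Returns true if II was replaced.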
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) {
  VectorType *DataType;
  Value *StoreVal = nullptr, *Ptr, *Mask, *EVL = nullptr;
  MaybeAlign MA;
  switch (II->getIntrinsicID()) {
  case Intrinsic::masked_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
    MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
    Mask = II->getArgOperand(2);
    break;
  case Intrinsic::vp_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
    MA = II->getParamAlign(0).value_or(
        DL->getABITypeAlign(DataType->getElementType()));
    Mask = II->getArgOperand(1);
    EVL = II->getArgOperand(2);
    break;
  case Intrinsic::masked_scatter:
    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
    StoreVal = II->getArgOperand(0);
    Ptr = II->getArgOperand(1);
    MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue();
    Mask = II->getArgOperand(3);
    break;
  case Intrinsic::vp_scatter:
    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
    StoreVal = II->getArgOperand(0);
    Ptr = II->getArgOperand(1);
    MA = II->getParamAlign(1).value_or(
        DL->getABITypeAlign(DataType->getElementType()));
    Mask = II->getArgOperand(2);
    EVL = II->getArgOperand(3);
    break;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }

  // Make sure the operation will be supported by the backend.
  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  if (!TLI->isTypeLegal(DataTypeVT))
    return false;

  // Pointer should be an instruction.
  auto *PtrI = dyn_cast<Instruction>(Ptr);
  if (!PtrI)
    return false;

  LLVMContext &Ctx = PtrI->getContext();
  IRBuilder Builder(Ctx, InstSimplifyFolder(*DL));
  Builder.SetInsertPoint(PtrI);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);
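
  // The masked intrinsics carry no EVL operand, so use the full element count
  // of the vector type as the EVL.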
  if (!EVL)
    EVL = Builder.CreateElementCount(
        Builder.getInt32Ty(), cast<VectorType>(DataType)->getElementCount());

  CallInst *Call;

  if (!StoreVal) {
    Call = Builder.CreateIntrinsic(
        Intrinsic::experimental_vp_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {BasePtr, Stride, Mask, EVL});

    // Merge llvm.masked.gather's passthru.
    if (II->getIntrinsicID() == Intrinsic::masked_gather)
      Call = Builder.CreateIntrinsic(Intrinsic::vp_select, {DataType},
                                     {Mask, Call, II->getArgOperand(3), EVL});
  } else
    Call = Builder.CreateIntrinsic(
        Intrinsic::experimental_vp_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {StoreVal, BasePtr, Stride, Mask, EVL});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (PtrI->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(PtrI);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Worklist;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (!II)
        continue;
      switch (II->getIntrinsicID()) {
      case Intrinsic::masked_gather:
      case Intrinsic::masked_scatter:
      case Intrinsic::vp_gather:
      case Intrinsic::vp_scatter:
        Worklist.push_back(II);
        break;
      default:
        break;
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Worklist)
    Changed |= tryCreateStridedLoadStore(II);

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}