//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter (and
// their llvm.vp.gather/llvm.vp.scatter counterparts) to the
// llvm.experimental.vp.strided.load/store intrinsics when the vector of
// pointers can be shown to have a constant stride.
//
//===----------------------------------------------------------------------===//
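//
// A rough sketch of the rewrite this pass performs (illustrative IR only, not
// taken from an actual test):
//
//   %ptrs = getelementptr i32, ptr %base,
//                         <4 x i64> <i64 0, i64 2, i64 4, i64 6>
//   %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(
//            <4 x ptr> %ptrs, i32 4, <4 x i1> splat (i1 true), <4 x i32> poison)
//
// becomes, roughly, a strided load of four elements spaced 8 bytes apart:
//
//   %v = call <4 x i32> @llvm.experimental.vp.strided.load.v4i32.p0.i64(
//            ptr %base, i64 8, <4 x i1> splat (i1 true), i32 4)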

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISC-V gather/scatter lowering";
  }

private:
  bool tryCreateStridedLoadStore(IntrinsicInst *II);

  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                     IRBuilderBase &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilderBase &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISC-V gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

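// matchStridedConstant returns the (start, stride) pair of a constant vector
// whose elements form an arithmetic sequence. Illustrative example (not from
// the source): <4 x i64> <i64 10, i64 12, i64 14, i64 16> yields start 10 and
// stride 2, while <4 x i64> <i64 0, i64 1, i64 3, i64 6> fails because the
// element-to-element difference is not uniform.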
// TODO: Should we consider the mask when looking for a stride?
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  if (!isa<FixedVectorType>(StartC->getType()))
    return std::make_pair(nullptr, nullptr);

  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

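// matchStridedStart peels splatted add/or/mul/shl layers off a start value
// until it reaches a strided constant or a stepvector. Illustrative example
// (not from the source): for
//   %step = call <4 x i64> @llvm.stepvector.v4i64()
//   %idx  = shl <4 x i64> %step, splat (i64 1)
// the stepvector base case gives (start 0, stride 1), and the shl layer shifts
// both, so the overall result is (start 0, stride 2).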
static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilderBase &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant or stepvector; maybe it's a strided start with a splat
  // added or multiplied.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || (BO->getOpcode() != Instruction::Add &&
              BO->getOpcode() != Instruction::Or &&
              BO->getOpcode() != Instruction::Shl &&
              BO->getOpcode() != Instruction::Mul))
    return std::make_pair(nullptr, nullptr);

  if (BO->getOpcode() == Instruction::Or &&
      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 0;
  Value *Splat = getSplatValue(BO->getOperand(1));
  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
    Splat = getSplatValue(BO->getOperand(0));
    OtherIndex = 1;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
                                              Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  // Add the splat value to the start or multiply the start and stride by the
  // splat.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case Instruction::Or:
    // TODO: We'd be better off creating disjoint or here, but we don't yet
    // have an IRBuilder API for that.
    [[fallthrough]];
  case Instruction::Add:
    Start = Builder.CreateAdd(Start, Splat);
    break;
  case Instruction::Mul:
    Start = Builder.CreateMul(Start, Splat);
    Stride = Builder.CreateMul(Stride, Splat);
    break;
  case Instruction::Shl:
    Start = Builder.CreateShl(Start, Splat);
    Stride = Builder.CreateShl(Stride, Splat);
    break;
  }

  return std::make_pair(Start, Stride);
}

// Recursively walk back up the use-def chain until we find a Phi with a
// strided start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
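//
// Illustrative example (not from the source): given the vector recurrence
//   %vec.iv  = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %ph ],
//                            [ %vec.inc, %loop ]
//   %vec.inc = add <4 x i64> %vec.iv, splat (i64 4)
// we build the scalar recurrence
//   %iv.scalar  = phi i64 [ 0, %ph ], [ %inc.scalar, %loop ]
//   %inc.scalar = add i64 %iv.scalar, 4
// and report a stride of 1.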
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilderBase &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar",
                              Phi->getIterator());
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc->getIterator());
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  switch (BO->getOpcode()) {
  default:
    return false;
  case Instruction::Or:
    // We need to be able to treat Or as Add.
    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
      return false;
    break;
  case Instruction::Add:
  case Instruction::Shl:
  case Instruction::Mul:
    break;
  }

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1))) &&
             Instruction::isCommutative(BO->getOpcode())) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  // TODO: Share this switch with matchStridedStart?
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
    Start = Builder.CreateAdd(Start, SplatOp, "start");
    break;
  }
  case Instruction::Mul: {
    Start = Builder.CreateMul(Start, SplatOp, "start");
    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    break;
  }
  case Instruction::Shl: {
    Start = Builder.CreateShl(Start, SplatOp, "start");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    break;
  }
  }
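
  // Illustrative effect so far (not from the source): if Index is
  // mul(%rec, splat 3) and %rec has start S and step T, the scalar recurrence
  // now starts at 3*S (computed above in the preheader), Stride has been
  // tripled, and the step is adjusted to 3*T below.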

  // If the Step was defined inside the loop, adjust it before its definition
  // instead of in the preheader.
  if (auto *StepI = dyn_cast<Instruction>(Step); StepI && L->contains(StepI))
    Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());

  switch (BO->getOpcode()) {
  default:
    break;
  case Instruction::Mul:
    Step = Builder.CreateMul(Step, SplatOp, "step");
    break;
  case Instruction::Shl:
    Step = Builder.CreateShl(Step, SplatOp, "step");
    break;
  }

  Inc->setOperand(StepIndex, Step);
  BasePtr->setIncomingValue(StartBlock, Start);
  return true;
}

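// determineBaseAndStride tries to express the pointer operand of a
// gather/scatter as a (scalar base pointer, byte stride) pair. Illustrative
// example (not from the source):
//   %ptrs = getelementptr i64, ptr %p, <4 x i64> <i64 0, i64 1, i64 2, i64 3>
// yields base %p and stride 8: the element stride of 1 recovered from the
// constant index is scaled by the 8-byte size of i64.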
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                   IRBuilderBase &Builder) {

  // A gather/scatter of a splat is a zero strided load/store.
  if (auto *BasePtr = getSplatValue(Ptr)) {
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
  }

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return std::make_pair(nullptr, nullptr);

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // If the base pointer is a vector, check if it's strided.
  Value *Base = GEP->getPointerOperand();
  if (auto *BaseInst = dyn_cast<Instruction>(Base);
      BaseInst && BaseInst->getType()->isVectorTy()) {
    // If all of the GEP's indices are scalar, we can fold them into the base
    // pointer's base.
    auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
    if (all_of(GEP->indices(), IsScalar)) {
      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
      if (BaseBase) {
        Builder.SetInsertPoint(GEP);
        SmallVector<Value *> Indices(GEP->indices());
        Value *OffsetBase =
            Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
                              GEP->getName() + "offset", GEP->isInBounds());
        return {OffsetBase, Stride};
      }
    }
  }
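
  // Illustrative example of the vector-base case above (not from the source):
  //   %vbase = getelementptr i32, ptr %p, <4 x i64> <i64 0, i64 2, i64 4, i64 6>
  //   %ptrs  = getelementptr i32, <4 x ptr> %vbase, i64 1
  // recurses on %vbase (base %p, byte stride 8) and folds the scalar index
  // into a scalar GEP on that base, keeping the same stride.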

  // Base pointer needs to be a scalar.
  Value *ScalarBase = Base;
  if (ScalarBase->getType()->isVectorTy()) {
    ScalarBase = getSplatValue(ScalarBase);
    if (!ScalarBase)
      return std::make_pair(nullptr, nullptr);
  }

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = GTI.getSequentialElementStride(*DL);
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice. Handle one special case here: constants. This simplifies
  // writing test cases.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy) {
    auto *VecIndexC = dyn_cast<Constant>(VecIndex);
    if (!VecIndexC)
      return std::make_pair(nullptr, nullptr);
    if (VecIndex->getType()->getScalarSizeInBits() >
        VecIntPtrTy->getScalarSizeInBits())
      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
                                             VecIntPtrTy);
    else
      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
                                             VecIntPtrTy);
  }
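
  // Illustrative example of the cast above (not from the source): with 64-bit
  // pointers, a constant index <2 x i32> <i32 0, i32 1> is constant folded
  // (sext) to <2 x i64> <i64 0, i64 1> so the stride arithmetic happens at
  // pointer width.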

  // Handle the non-recursive case.  This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

    // The stride should already be at pointer width.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // The stride should already be at pointer width.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

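// tryCreateStridedLoadStore rewrites a single gather/scatter whose address is
// strided. Illustrative shape of the masked.gather case (not from the source):
// the passthru operand is preserved by selecting between the strided load and
// the passthru under the original mask, roughly
//   %l = call @llvm.experimental.vp.strided.load(ptr %base, i64 %stride,
//                                                <n x i1> %mask, i32 %evl)
//   %v = call @llvm.vp.select(<n x i1> %mask, %l, <passthru>, i32 %evl)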
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) {
  VectorType *DataType;
  Value *StoreVal = nullptr, *Ptr, *Mask, *EVL = nullptr;
  MaybeAlign MA;
  switch (II->getIntrinsicID()) {
  case Intrinsic::masked_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
    MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
    Mask = II->getArgOperand(2);
    break;
  case Intrinsic::vp_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
    MA = II->getParamAlign(0).value_or(
        DL->getABITypeAlign(DataType->getElementType()));
    Mask = II->getArgOperand(1);
    EVL = II->getArgOperand(2);
    break;
  case Intrinsic::masked_scatter:
    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
    StoreVal = II->getArgOperand(0);
    Ptr = II->getArgOperand(1);
    MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue();
    Mask = II->getArgOperand(3);
    break;
  case Intrinsic::vp_scatter:
    DataType = cast<VectorType>(II->getArgOperand(0)->getType());
    StoreVal = II->getArgOperand(0);
    Ptr = II->getArgOperand(1);
    MA = II->getParamAlign(1).value_or(
        DL->getABITypeAlign(DataType->getElementType()));
    Mask = II->getArgOperand(2);
    EVL = II->getArgOperand(3);
    break;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }

  // Make sure the operation will be supported by the backend.
  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  if (!TLI->isTypeLegal(DataTypeVT))
    return false;

  // Pointer should be an instruction.
  auto *PtrI = dyn_cast<Instruction>(Ptr);
  if (!PtrI)
    return false;

  LLVMContext &Ctx = PtrI->getContext();
  IRBuilder Builder(Ctx, InstSimplifyFolder(*DL));
  Builder.SetInsertPoint(PtrI);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

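  // Non-VP gathers/scatters carry no EVL operand; default to the full element
  // count of the vector type.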
  if (!EVL)
    EVL = Builder.CreateElementCount(
        Builder.getInt32Ty(), cast<VectorType>(DataType)->getElementCount());

  CallInst *Call;

  if (!StoreVal) {
    Call = Builder.CreateIntrinsic(
        Intrinsic::experimental_vp_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {BasePtr, Stride, Mask, EVL});

    // Merge llvm.masked.gather's passthru.
    if (II->getIntrinsicID() == Intrinsic::masked_gather)
      Call = Builder.CreateIntrinsic(Intrinsic::vp_select, {DataType},
                                     {Mask, Call, II->getArgOperand(3), EVL});
  } else
    Call = Builder.CreateIntrinsic(
        Intrinsic::experimental_vp_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {StoreVal, BasePtr, Stride, Mask, EVL});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (PtrI->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(PtrI);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Worklist;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (!II)
        continue;
      switch (II->getIntrinsicID()) {
      case Intrinsic::masked_gather:
      case Intrinsic::masked_scatter:
      case Intrinsic::vp_gather:
      case Intrinsic::vp_scatter:
        Worklist.push_back(II);
        break;
      default:
        break;
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Worklist)
    Changed |= tryCreateStridedLoadStore(II);

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}