xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (revision 1db9f3b21e39176dd5b67cf8ac378633b172463e)
1349cc55cSDimitry Andric //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2349cc55cSDimitry Andric //
3349cc55cSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4349cc55cSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5349cc55cSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6349cc55cSDimitry Andric //
7349cc55cSDimitry Andric //===----------------------------------------------------------------------===//
8349cc55cSDimitry Andric //
9349cc55cSDimitry Andric // This pass custom lowers llvm.gather and llvm.scatter instructions to
1006c3fb27SDimitry Andric // RISC-V intrinsics.
11349cc55cSDimitry Andric //
12349cc55cSDimitry Andric //===----------------------------------------------------------------------===//
13349cc55cSDimitry Andric 
14349cc55cSDimitry Andric #include "RISCV.h"
15349cc55cSDimitry Andric #include "RISCVTargetMachine.h"
1606c3fb27SDimitry Andric #include "llvm/Analysis/InstSimplifyFolder.h"
17349cc55cSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
18349cc55cSDimitry Andric #include "llvm/Analysis/ValueTracking.h"
19349cc55cSDimitry Andric #include "llvm/Analysis/VectorUtils.h"
20349cc55cSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
21349cc55cSDimitry Andric #include "llvm/IR/GetElementPtrTypeIterator.h"
22349cc55cSDimitry Andric #include "llvm/IR/IRBuilder.h"
23349cc55cSDimitry Andric #include "llvm/IR/IntrinsicInst.h"
24349cc55cSDimitry Andric #include "llvm/IR/IntrinsicsRISCV.h"
25bdd1243dSDimitry Andric #include "llvm/IR/PatternMatch.h"
26349cc55cSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
27bdd1243dSDimitry Andric #include <optional>
28349cc55cSDimitry Andric 
29349cc55cSDimitry Andric using namespace llvm;
30bdd1243dSDimitry Andric using namespace PatternMatch;
31349cc55cSDimitry Andric 
32349cc55cSDimitry Andric #define DEBUG_TYPE "riscv-gather-scatter-lowering"
33349cc55cSDimitry Andric 
34349cc55cSDimitry Andric namespace {
35349cc55cSDimitry Andric 
36349cc55cSDimitry Andric class RISCVGatherScatterLowering : public FunctionPass {
37349cc55cSDimitry Andric   const RISCVSubtarget *ST = nullptr;
38349cc55cSDimitry Andric   const RISCVTargetLowering *TLI = nullptr;
39349cc55cSDimitry Andric   LoopInfo *LI = nullptr;
40349cc55cSDimitry Andric   const DataLayout *DL = nullptr;
41349cc55cSDimitry Andric 
42349cc55cSDimitry Andric   SmallVector<WeakTrackingVH> MaybeDeadPHIs;
43349cc55cSDimitry Andric 
4481ad6265SDimitry Andric   // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
4581ad6265SDimitry Andric   // used by multiple gathers/scatters, this allow us to reuse the scalar
4681ad6265SDimitry Andric   // instructions we created for the first gather/scatter for the others.
4781ad6265SDimitry Andric   DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;
4881ad6265SDimitry Andric 
49349cc55cSDimitry Andric public:
50349cc55cSDimitry Andric   static char ID; // Pass identification, replacement for typeid
51349cc55cSDimitry Andric 
52349cc55cSDimitry Andric   RISCVGatherScatterLowering() : FunctionPass(ID) {}
53349cc55cSDimitry Andric 
54349cc55cSDimitry Andric   bool runOnFunction(Function &F) override;
55349cc55cSDimitry Andric 
56349cc55cSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
57349cc55cSDimitry Andric     AU.setPreservesCFG();
58349cc55cSDimitry Andric     AU.addRequired<TargetPassConfig>();
59349cc55cSDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
60349cc55cSDimitry Andric   }
61349cc55cSDimitry Andric 
62349cc55cSDimitry Andric   StringRef getPassName() const override {
6306c3fb27SDimitry Andric     return "RISC-V gather/scatter lowering";
64349cc55cSDimitry Andric   }
65349cc55cSDimitry Andric 
66349cc55cSDimitry Andric private:
67349cc55cSDimitry Andric   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
68349cc55cSDimitry Andric                                  Value *AlignOp);
69349cc55cSDimitry Andric 
705f757f3fSDimitry Andric   std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
7106c3fb27SDimitry Andric                                                      IRBuilderBase &Builder);
72349cc55cSDimitry Andric 
73349cc55cSDimitry Andric   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
74349cc55cSDimitry Andric                               PHINode *&BasePtr, BinaryOperator *&Inc,
7506c3fb27SDimitry Andric                               IRBuilderBase &Builder);
76349cc55cSDimitry Andric };
77349cc55cSDimitry Andric 
78349cc55cSDimitry Andric } // end anonymous namespace
79349cc55cSDimitry Andric 
80349cc55cSDimitry Andric char RISCVGatherScatterLowering::ID = 0;
81349cc55cSDimitry Andric 
82349cc55cSDimitry Andric INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
8306c3fb27SDimitry Andric                 "RISC-V gather/scatter lowering pass", false, false)
84349cc55cSDimitry Andric 
85349cc55cSDimitry Andric FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
86349cc55cSDimitry Andric   return new RISCVGatherScatterLowering();
87349cc55cSDimitry Andric }
88349cc55cSDimitry Andric 
89349cc55cSDimitry Andric // TODO: Should we consider the mask when looking for a stride?
90349cc55cSDimitry Andric static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
9106c3fb27SDimitry Andric   if (!isa<FixedVectorType>(StartC->getType()))
9206c3fb27SDimitry Andric     return std::make_pair(nullptr, nullptr);
9306c3fb27SDimitry Andric 
94349cc55cSDimitry Andric   unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
95349cc55cSDimitry Andric 
96349cc55cSDimitry Andric   // Check that the start value is a strided constant.
97349cc55cSDimitry Andric   auto *StartVal =
98349cc55cSDimitry Andric       dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
99349cc55cSDimitry Andric   if (!StartVal)
100349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
101349cc55cSDimitry Andric   APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
102349cc55cSDimitry Andric   ConstantInt *Prev = StartVal;
103349cc55cSDimitry Andric   for (unsigned i = 1; i != NumElts; ++i) {
104349cc55cSDimitry Andric     auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
105349cc55cSDimitry Andric     if (!C)
106349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
107349cc55cSDimitry Andric 
108349cc55cSDimitry Andric     APInt LocalStride = C->getValue() - Prev->getValue();
109349cc55cSDimitry Andric     if (i == 1)
110349cc55cSDimitry Andric       StrideVal = LocalStride;
111349cc55cSDimitry Andric     else if (StrideVal != LocalStride)
112349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
113349cc55cSDimitry Andric 
114349cc55cSDimitry Andric     Prev = C;
115349cc55cSDimitry Andric   }
116349cc55cSDimitry Andric 
117349cc55cSDimitry Andric   Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
118349cc55cSDimitry Andric 
119349cc55cSDimitry Andric   return std::make_pair(StartVal, Stride);
120349cc55cSDimitry Andric }
121349cc55cSDimitry Andric 
12204eeddc0SDimitry Andric static std::pair<Value *, Value *> matchStridedStart(Value *Start,
12306c3fb27SDimitry Andric                                                      IRBuilderBase &Builder) {
12404eeddc0SDimitry Andric   // Base case, start is a strided constant.
12504eeddc0SDimitry Andric   auto *StartC = dyn_cast<Constant>(Start);
12604eeddc0SDimitry Andric   if (StartC)
12704eeddc0SDimitry Andric     return matchStridedConstant(StartC);
12804eeddc0SDimitry Andric 
129bdd1243dSDimitry Andric   // Base case, start is a stepvector
130bdd1243dSDimitry Andric   if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
131bdd1243dSDimitry Andric     auto *Ty = Start->getType()->getScalarType();
132bdd1243dSDimitry Andric     return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
133bdd1243dSDimitry Andric   }
134bdd1243dSDimitry Andric 
13506c3fb27SDimitry Andric   // Not a constant, maybe it's a strided constant with a splat added or
13606c3fb27SDimitry Andric   // multipled.
13704eeddc0SDimitry Andric   auto *BO = dyn_cast<BinaryOperator>(Start);
13806c3fb27SDimitry Andric   if (!BO || (BO->getOpcode() != Instruction::Add &&
13906c3fb27SDimitry Andric               BO->getOpcode() != Instruction::Shl &&
14006c3fb27SDimitry Andric               BO->getOpcode() != Instruction::Mul))
14104eeddc0SDimitry Andric     return std::make_pair(nullptr, nullptr);
14204eeddc0SDimitry Andric 
14304eeddc0SDimitry Andric   // Look for an operand that is splatted.
14406c3fb27SDimitry Andric   unsigned OtherIndex = 0;
14506c3fb27SDimitry Andric   Value *Splat = getSplatValue(BO->getOperand(1));
14606c3fb27SDimitry Andric   if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
14706c3fb27SDimitry Andric     Splat = getSplatValue(BO->getOperand(0));
14806c3fb27SDimitry Andric     OtherIndex = 1;
14904eeddc0SDimitry Andric   }
15004eeddc0SDimitry Andric   if (!Splat)
15104eeddc0SDimitry Andric     return std::make_pair(nullptr, nullptr);
15204eeddc0SDimitry Andric 
15304eeddc0SDimitry Andric   Value *Stride;
15404eeddc0SDimitry Andric   std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
15504eeddc0SDimitry Andric                                               Builder);
15604eeddc0SDimitry Andric   if (!Start)
15704eeddc0SDimitry Andric     return std::make_pair(nullptr, nullptr);
15804eeddc0SDimitry Andric 
15904eeddc0SDimitry Andric   Builder.SetInsertPoint(BO);
16004eeddc0SDimitry Andric   Builder.SetCurrentDebugLocation(DebugLoc());
16106c3fb27SDimitry Andric   // Add the splat value to the start or multiply the start and stride by the
16206c3fb27SDimitry Andric   // splat.
16306c3fb27SDimitry Andric   switch (BO->getOpcode()) {
16406c3fb27SDimitry Andric   default:
16506c3fb27SDimitry Andric     llvm_unreachable("Unexpected opcode");
16606c3fb27SDimitry Andric   case Instruction::Add:
16704eeddc0SDimitry Andric     Start = Builder.CreateAdd(Start, Splat);
16806c3fb27SDimitry Andric     break;
16906c3fb27SDimitry Andric   case Instruction::Mul:
17006c3fb27SDimitry Andric     Start = Builder.CreateMul(Start, Splat);
17106c3fb27SDimitry Andric     Stride = Builder.CreateMul(Stride, Splat);
17206c3fb27SDimitry Andric     break;
17306c3fb27SDimitry Andric   case Instruction::Shl:
17406c3fb27SDimitry Andric     Start = Builder.CreateShl(Start, Splat);
17506c3fb27SDimitry Andric     Stride = Builder.CreateShl(Stride, Splat);
17606c3fb27SDimitry Andric     break;
17706c3fb27SDimitry Andric   }
17806c3fb27SDimitry Andric 
17904eeddc0SDimitry Andric   return std::make_pair(Start, Stride);
18004eeddc0SDimitry Andric }
18104eeddc0SDimitry Andric 
182349cc55cSDimitry Andric // Recursively, walk about the use-def chain until we find a Phi with a strided
183349cc55cSDimitry Andric // start value. Build and update a scalar recurrence as we unwind the recursion.
184349cc55cSDimitry Andric // We also update the Stride as we unwind. Our goal is to move all of the
185349cc55cSDimitry Andric // arithmetic out of the loop.
186349cc55cSDimitry Andric bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
187349cc55cSDimitry Andric                                                         Value *&Stride,
188349cc55cSDimitry Andric                                                         PHINode *&BasePtr,
189349cc55cSDimitry Andric                                                         BinaryOperator *&Inc,
19006c3fb27SDimitry Andric                                                         IRBuilderBase &Builder) {
191349cc55cSDimitry Andric   // Our base case is a Phi.
192349cc55cSDimitry Andric   if (auto *Phi = dyn_cast<PHINode>(Index)) {
193349cc55cSDimitry Andric     // A phi node we want to perform this function on should be from the
194349cc55cSDimitry Andric     // loop header.
195349cc55cSDimitry Andric     if (Phi->getParent() != L->getHeader())
196349cc55cSDimitry Andric       return false;
197349cc55cSDimitry Andric 
198349cc55cSDimitry Andric     Value *Step, *Start;
199349cc55cSDimitry Andric     if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
200349cc55cSDimitry Andric         Inc->getOpcode() != Instruction::Add)
201349cc55cSDimitry Andric       return false;
202349cc55cSDimitry Andric     assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
203349cc55cSDimitry Andric     unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
204349cc55cSDimitry Andric     assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
205349cc55cSDimitry Andric            "Expected one operand of phi to be Inc");
206349cc55cSDimitry Andric 
207349cc55cSDimitry Andric     // Only proceed if the step is loop invariant.
208349cc55cSDimitry Andric     if (!L->isLoopInvariant(Step))
209349cc55cSDimitry Andric       return false;
210349cc55cSDimitry Andric 
211349cc55cSDimitry Andric     // Step should be a splat.
212349cc55cSDimitry Andric     Step = getSplatValue(Step);
213349cc55cSDimitry Andric     if (!Step)
214349cc55cSDimitry Andric       return false;
215349cc55cSDimitry Andric 
21604eeddc0SDimitry Andric     std::tie(Start, Stride) = matchStridedStart(Start, Builder);
217349cc55cSDimitry Andric     if (!Start)
218349cc55cSDimitry Andric       return false;
219349cc55cSDimitry Andric     assert(Stride != nullptr);
220349cc55cSDimitry Andric 
221349cc55cSDimitry Andric     // Build scalar phi and increment.
222349cc55cSDimitry Andric     BasePtr =
223349cc55cSDimitry Andric         PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
224349cc55cSDimitry Andric     Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
225349cc55cSDimitry Andric                                     Inc);
226349cc55cSDimitry Andric     BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
227349cc55cSDimitry Andric     BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
228349cc55cSDimitry Andric 
229349cc55cSDimitry Andric     // Note that this Phi might be eligible for removal.
230349cc55cSDimitry Andric     MaybeDeadPHIs.push_back(Phi);
231349cc55cSDimitry Andric     return true;
232349cc55cSDimitry Andric   }
233349cc55cSDimitry Andric 
234349cc55cSDimitry Andric   // Otherwise look for binary operator.
235349cc55cSDimitry Andric   auto *BO = dyn_cast<BinaryOperator>(Index);
236349cc55cSDimitry Andric   if (!BO)
237349cc55cSDimitry Andric     return false;
238349cc55cSDimitry Andric 
23906c3fb27SDimitry Andric   switch (BO->getOpcode()) {
24006c3fb27SDimitry Andric   default:
241349cc55cSDimitry Andric     return false;
24206c3fb27SDimitry Andric   case Instruction::Or:
243349cc55cSDimitry Andric     // We need to be able to treat Or as Add.
24406c3fb27SDimitry Andric     if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
245349cc55cSDimitry Andric       return false;
24606c3fb27SDimitry Andric     break;
24706c3fb27SDimitry Andric   case Instruction::Add:
24806c3fb27SDimitry Andric     break;
24906c3fb27SDimitry Andric   case Instruction::Shl:
25006c3fb27SDimitry Andric     break;
25106c3fb27SDimitry Andric   case Instruction::Mul:
25206c3fb27SDimitry Andric     break;
25306c3fb27SDimitry Andric   }
254349cc55cSDimitry Andric 
255349cc55cSDimitry Andric   // We should have one operand in the loop and one splat.
256349cc55cSDimitry Andric   Value *OtherOp;
257349cc55cSDimitry Andric   if (isa<Instruction>(BO->getOperand(0)) &&
258349cc55cSDimitry Andric       L->contains(cast<Instruction>(BO->getOperand(0)))) {
259349cc55cSDimitry Andric     Index = cast<Instruction>(BO->getOperand(0));
260349cc55cSDimitry Andric     OtherOp = BO->getOperand(1);
261349cc55cSDimitry Andric   } else if (isa<Instruction>(BO->getOperand(1)) &&
26206c3fb27SDimitry Andric              L->contains(cast<Instruction>(BO->getOperand(1))) &&
26306c3fb27SDimitry Andric              Instruction::isCommutative(BO->getOpcode())) {
264349cc55cSDimitry Andric     Index = cast<Instruction>(BO->getOperand(1));
265349cc55cSDimitry Andric     OtherOp = BO->getOperand(0);
266349cc55cSDimitry Andric   } else {
267349cc55cSDimitry Andric     return false;
268349cc55cSDimitry Andric   }
269349cc55cSDimitry Andric 
270349cc55cSDimitry Andric   // Make sure other op is loop invariant.
271349cc55cSDimitry Andric   if (!L->isLoopInvariant(OtherOp))
272349cc55cSDimitry Andric     return false;
273349cc55cSDimitry Andric 
274349cc55cSDimitry Andric   // Make sure we have a splat.
275349cc55cSDimitry Andric   Value *SplatOp = getSplatValue(OtherOp);
276349cc55cSDimitry Andric   if (!SplatOp)
277349cc55cSDimitry Andric     return false;
278349cc55cSDimitry Andric 
279349cc55cSDimitry Andric   // Recurse up the use-def chain.
280349cc55cSDimitry Andric   if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
281349cc55cSDimitry Andric     return false;
282349cc55cSDimitry Andric 
283349cc55cSDimitry Andric   // Locate the Step and Start values from the recurrence.
284349cc55cSDimitry Andric   unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
285349cc55cSDimitry Andric   unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
286349cc55cSDimitry Andric   Value *Step = Inc->getOperand(StepIndex);
287349cc55cSDimitry Andric   Value *Start = BasePtr->getOperand(StartBlock);
288349cc55cSDimitry Andric 
289349cc55cSDimitry Andric   // We need to adjust the start value in the preheader.
290349cc55cSDimitry Andric   Builder.SetInsertPoint(
291349cc55cSDimitry Andric       BasePtr->getIncomingBlock(StartBlock)->getTerminator());
292349cc55cSDimitry Andric   Builder.SetCurrentDebugLocation(DebugLoc());
293349cc55cSDimitry Andric 
294349cc55cSDimitry Andric   switch (BO->getOpcode()) {
295349cc55cSDimitry Andric   default:
296349cc55cSDimitry Andric     llvm_unreachable("Unexpected opcode!");
297349cc55cSDimitry Andric   case Instruction::Add:
298349cc55cSDimitry Andric   case Instruction::Or: {
299349cc55cSDimitry Andric     // An add only affects the start value. It's ok to do this for Or because
300349cc55cSDimitry Andric     // we already checked that there are no common set bits.
301349cc55cSDimitry Andric     Start = Builder.CreateAdd(Start, SplatOp, "start");
302349cc55cSDimitry Andric     break;
303349cc55cSDimitry Andric   }
304349cc55cSDimitry Andric   case Instruction::Mul: {
305349cc55cSDimitry Andric     Start = Builder.CreateMul(Start, SplatOp, "start");
306349cc55cSDimitry Andric     Step = Builder.CreateMul(Step, SplatOp, "step");
307349cc55cSDimitry Andric     Stride = Builder.CreateMul(Stride, SplatOp, "stride");
308349cc55cSDimitry Andric     break;
309349cc55cSDimitry Andric   }
310349cc55cSDimitry Andric   case Instruction::Shl: {
311349cc55cSDimitry Andric     Start = Builder.CreateShl(Start, SplatOp, "start");
312349cc55cSDimitry Andric     Step = Builder.CreateShl(Step, SplatOp, "step");
313349cc55cSDimitry Andric     Stride = Builder.CreateShl(Stride, SplatOp, "stride");
314349cc55cSDimitry Andric     break;
315349cc55cSDimitry Andric   }
316349cc55cSDimitry Andric   }
317349cc55cSDimitry Andric 
31806c3fb27SDimitry Andric   Inc->setOperand(StepIndex, Step);
31906c3fb27SDimitry Andric   BasePtr->setIncomingValue(StartBlock, Start);
320349cc55cSDimitry Andric   return true;
321349cc55cSDimitry Andric }
322349cc55cSDimitry Andric 
323349cc55cSDimitry Andric std::pair<Value *, Value *>
3245f757f3fSDimitry Andric RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
32506c3fb27SDimitry Andric                                                    IRBuilderBase &Builder) {
326349cc55cSDimitry Andric 
3275f757f3fSDimitry Andric   // A gather/scatter of a splat is a zero strided load/store.
3285f757f3fSDimitry Andric   if (auto *BasePtr = getSplatValue(Ptr)) {
3295f757f3fSDimitry Andric     Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
3305f757f3fSDimitry Andric     return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
3315f757f3fSDimitry Andric   }
3325f757f3fSDimitry Andric 
3335f757f3fSDimitry Andric   auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
3345f757f3fSDimitry Andric   if (!GEP)
3355f757f3fSDimitry Andric     return std::make_pair(nullptr, nullptr);
3365f757f3fSDimitry Andric 
33781ad6265SDimitry Andric   auto I = StridedAddrs.find(GEP);
33881ad6265SDimitry Andric   if (I != StridedAddrs.end())
33981ad6265SDimitry Andric     return I->second;
34081ad6265SDimitry Andric 
341349cc55cSDimitry Andric   SmallVector<Value *, 2> Ops(GEP->operands());
342349cc55cSDimitry Andric 
343349cc55cSDimitry Andric   // Base pointer needs to be a scalar.
3445f757f3fSDimitry Andric   Value *ScalarBase = Ops[0];
3455f757f3fSDimitry Andric   if (ScalarBase->getType()->isVectorTy()) {
3465f757f3fSDimitry Andric     ScalarBase = getSplatValue(ScalarBase);
3475f757f3fSDimitry Andric     if (!ScalarBase)
348349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
3495f757f3fSDimitry Andric   }
350349cc55cSDimitry Andric 
351bdd1243dSDimitry Andric   std::optional<unsigned> VecOperand;
352349cc55cSDimitry Andric   unsigned TypeScale = 0;
353349cc55cSDimitry Andric 
354349cc55cSDimitry Andric   // Look for a vector operand and scale.
355349cc55cSDimitry Andric   gep_type_iterator GTI = gep_type_begin(GEP);
356349cc55cSDimitry Andric   for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
357349cc55cSDimitry Andric     if (!Ops[i]->getType()->isVectorTy())
358349cc55cSDimitry Andric       continue;
359349cc55cSDimitry Andric 
360349cc55cSDimitry Andric     if (VecOperand)
361349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
362349cc55cSDimitry Andric 
363349cc55cSDimitry Andric     VecOperand = i;
364349cc55cSDimitry Andric 
365*1db9f3b2SDimitry Andric     TypeSize TS = GTI.getSequentialElementStride(*DL);
366349cc55cSDimitry Andric     if (TS.isScalable())
367349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
368349cc55cSDimitry Andric 
369bdd1243dSDimitry Andric     TypeScale = TS.getFixedValue();
370349cc55cSDimitry Andric   }
371349cc55cSDimitry Andric 
372349cc55cSDimitry Andric   // We need to find a vector index to simplify.
373349cc55cSDimitry Andric   if (!VecOperand)
374349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
375349cc55cSDimitry Andric 
376349cc55cSDimitry Andric   // We can't extract the stride if the arithmetic is done at a different size
377349cc55cSDimitry Andric   // than the pointer type. Adding the stride later may not wrap correctly.
378349cc55cSDimitry Andric   // Technically we could handle wider indices, but I don't expect that in
3795f757f3fSDimitry Andric   // practice.  Handle one special case here - constants.  This simplifies
3805f757f3fSDimitry Andric   // writing test cases.
381349cc55cSDimitry Andric   Value *VecIndex = Ops[*VecOperand];
382349cc55cSDimitry Andric   Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
3835f757f3fSDimitry Andric   if (VecIndex->getType() != VecIntPtrTy) {
3845f757f3fSDimitry Andric     auto *VecIndexC = dyn_cast<Constant>(VecIndex);
3855f757f3fSDimitry Andric     if (!VecIndexC)
386349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
3875f757f3fSDimitry Andric     if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits())
3885f757f3fSDimitry Andric       VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC, VecIntPtrTy);
3895f757f3fSDimitry Andric     else
3905f757f3fSDimitry Andric       VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC, VecIntPtrTy);
3915f757f3fSDimitry Andric   }
392349cc55cSDimitry Andric 
393bdd1243dSDimitry Andric   // Handle the non-recursive case.  This is what we see if the vectorizer
394bdd1243dSDimitry Andric   // decides to use a scalar IV + vid on demand instead of a vector IV.
395bdd1243dSDimitry Andric   auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
396bdd1243dSDimitry Andric   if (Start) {
397bdd1243dSDimitry Andric     assert(Stride);
398bdd1243dSDimitry Andric     Builder.SetInsertPoint(GEP);
399bdd1243dSDimitry Andric 
400bdd1243dSDimitry Andric     // Replace the vector index with the scalar start and build a scalar GEP.
401bdd1243dSDimitry Andric     Ops[*VecOperand] = Start;
402bdd1243dSDimitry Andric     Type *SourceTy = GEP->getSourceElementType();
403bdd1243dSDimitry Andric     Value *BasePtr =
4045f757f3fSDimitry Andric         Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());
405bdd1243dSDimitry Andric 
406bdd1243dSDimitry Andric     // Convert stride to pointer size if needed.
407bdd1243dSDimitry Andric     Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
408bdd1243dSDimitry Andric     assert(Stride->getType() == IntPtrTy && "Unexpected type");
409bdd1243dSDimitry Andric 
410bdd1243dSDimitry Andric     // Scale the stride by the size of the indexed type.
411bdd1243dSDimitry Andric     if (TypeScale != 1)
412bdd1243dSDimitry Andric       Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
413bdd1243dSDimitry Andric 
414bdd1243dSDimitry Andric     auto P = std::make_pair(BasePtr, Stride);
415bdd1243dSDimitry Andric     StridedAddrs[GEP] = P;
416bdd1243dSDimitry Andric     return P;
417bdd1243dSDimitry Andric   }
418bdd1243dSDimitry Andric 
419bdd1243dSDimitry Andric   // Make sure we're in a loop and that has a pre-header and a single latch.
420bdd1243dSDimitry Andric   Loop *L = LI->getLoopFor(GEP->getParent());
421bdd1243dSDimitry Andric   if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
422bdd1243dSDimitry Andric     return std::make_pair(nullptr, nullptr);
423bdd1243dSDimitry Andric 
424349cc55cSDimitry Andric   BinaryOperator *Inc;
425349cc55cSDimitry Andric   PHINode *BasePhi;
426349cc55cSDimitry Andric   if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
427349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
428349cc55cSDimitry Andric 
429349cc55cSDimitry Andric   assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
430349cc55cSDimitry Andric   unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
431349cc55cSDimitry Andric   assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
432349cc55cSDimitry Andric          "Expected one operand of phi to be Inc");
433349cc55cSDimitry Andric 
434349cc55cSDimitry Andric   Builder.SetInsertPoint(GEP);
435349cc55cSDimitry Andric 
436349cc55cSDimitry Andric   // Replace the vector index with the scalar phi and build a scalar GEP.
437349cc55cSDimitry Andric   Ops[*VecOperand] = BasePhi;
438349cc55cSDimitry Andric   Type *SourceTy = GEP->getSourceElementType();
439349cc55cSDimitry Andric   Value *BasePtr =
4405f757f3fSDimitry Andric       Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());
441349cc55cSDimitry Andric 
442349cc55cSDimitry Andric   // Final adjustments to stride should go in the start block.
443349cc55cSDimitry Andric   Builder.SetInsertPoint(
444349cc55cSDimitry Andric       BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());
445349cc55cSDimitry Andric 
446349cc55cSDimitry Andric   // Convert stride to pointer size if needed.
447349cc55cSDimitry Andric   Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
448349cc55cSDimitry Andric   assert(Stride->getType() == IntPtrTy && "Unexpected type");
449349cc55cSDimitry Andric 
450349cc55cSDimitry Andric   // Scale the stride by the size of the indexed type.
451349cc55cSDimitry Andric   if (TypeScale != 1)
452349cc55cSDimitry Andric     Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
453349cc55cSDimitry Andric 
45481ad6265SDimitry Andric   auto P = std::make_pair(BasePtr, Stride);
45581ad6265SDimitry Andric   StridedAddrs[GEP] = P;
45681ad6265SDimitry Andric   return P;
457349cc55cSDimitry Andric }
458349cc55cSDimitry Andric 
459349cc55cSDimitry Andric bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
460349cc55cSDimitry Andric                                                            Type *DataType,
461349cc55cSDimitry Andric                                                            Value *Ptr,
462349cc55cSDimitry Andric                                                            Value *AlignOp) {
463349cc55cSDimitry Andric   // Make sure the operation will be supported by the backend.
46406c3fb27SDimitry Andric   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
46506c3fb27SDimitry Andric   EVT DataTypeVT = TLI->getValueType(*DL, DataType);
46606c3fb27SDimitry Andric   if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
46706c3fb27SDimitry Andric     return false;
46806c3fb27SDimitry Andric 
46906c3fb27SDimitry Andric   // FIXME: Let the backend type legalize by splitting/widening?
47006c3fb27SDimitry Andric   if (!TLI->isTypeLegal(DataTypeVT))
471349cc55cSDimitry Andric     return false;
472349cc55cSDimitry Andric 
4735f757f3fSDimitry Andric   // Pointer should be an instruction.
4745f757f3fSDimitry Andric   auto *PtrI = dyn_cast<Instruction>(Ptr);
4755f757f3fSDimitry Andric   if (!PtrI)
476349cc55cSDimitry Andric     return false;
477349cc55cSDimitry Andric 
4785f757f3fSDimitry Andric   LLVMContext &Ctx = PtrI->getContext();
47906c3fb27SDimitry Andric   IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
4805f757f3fSDimitry Andric   Builder.SetInsertPoint(PtrI);
481349cc55cSDimitry Andric 
482349cc55cSDimitry Andric   Value *BasePtr, *Stride;
4835f757f3fSDimitry Andric   std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
484349cc55cSDimitry Andric   if (!BasePtr)
485349cc55cSDimitry Andric     return false;
486349cc55cSDimitry Andric   assert(Stride != nullptr);
487349cc55cSDimitry Andric 
488349cc55cSDimitry Andric   Builder.SetInsertPoint(II);
489349cc55cSDimitry Andric 
490349cc55cSDimitry Andric   CallInst *Call;
491349cc55cSDimitry Andric   if (II->getIntrinsicID() == Intrinsic::masked_gather)
492349cc55cSDimitry Andric     Call = Builder.CreateIntrinsic(
493349cc55cSDimitry Andric         Intrinsic::riscv_masked_strided_load,
494349cc55cSDimitry Andric         {DataType, BasePtr->getType(), Stride->getType()},
495349cc55cSDimitry Andric         {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
496349cc55cSDimitry Andric   else
497349cc55cSDimitry Andric     Call = Builder.CreateIntrinsic(
498349cc55cSDimitry Andric         Intrinsic::riscv_masked_strided_store,
499349cc55cSDimitry Andric         {DataType, BasePtr->getType(), Stride->getType()},
500349cc55cSDimitry Andric         {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
501349cc55cSDimitry Andric 
502349cc55cSDimitry Andric   Call->takeName(II);
503349cc55cSDimitry Andric   II->replaceAllUsesWith(Call);
504349cc55cSDimitry Andric   II->eraseFromParent();
505349cc55cSDimitry Andric 
5065f757f3fSDimitry Andric   if (PtrI->use_empty())
5075f757f3fSDimitry Andric     RecursivelyDeleteTriviallyDeadInstructions(PtrI);
508349cc55cSDimitry Andric 
509349cc55cSDimitry Andric   return true;
510349cc55cSDimitry Andric }
511349cc55cSDimitry Andric 
512349cc55cSDimitry Andric bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
513349cc55cSDimitry Andric   if (skipFunction(F))
514349cc55cSDimitry Andric     return false;
515349cc55cSDimitry Andric 
516349cc55cSDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
517349cc55cSDimitry Andric   auto &TM = TPC.getTM<RISCVTargetMachine>();
518349cc55cSDimitry Andric   ST = &TM.getSubtarget<RISCVSubtarget>(F);
519349cc55cSDimitry Andric   if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
520349cc55cSDimitry Andric     return false;
521349cc55cSDimitry Andric 
522349cc55cSDimitry Andric   TLI = ST->getTargetLowering();
523349cc55cSDimitry Andric   DL = &F.getParent()->getDataLayout();
524349cc55cSDimitry Andric   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
525349cc55cSDimitry Andric 
52681ad6265SDimitry Andric   StridedAddrs.clear();
52781ad6265SDimitry Andric 
528349cc55cSDimitry Andric   SmallVector<IntrinsicInst *, 4> Gathers;
529349cc55cSDimitry Andric   SmallVector<IntrinsicInst *, 4> Scatters;
530349cc55cSDimitry Andric 
531349cc55cSDimitry Andric   bool Changed = false;
532349cc55cSDimitry Andric 
533349cc55cSDimitry Andric   for (BasicBlock &BB : F) {
534349cc55cSDimitry Andric     for (Instruction &I : BB) {
535349cc55cSDimitry Andric       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
536bdd1243dSDimitry Andric       if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
537349cc55cSDimitry Andric         Gathers.push_back(II);
538bdd1243dSDimitry Andric       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
539349cc55cSDimitry Andric         Scatters.push_back(II);
540349cc55cSDimitry Andric       }
541349cc55cSDimitry Andric     }
542349cc55cSDimitry Andric   }
543349cc55cSDimitry Andric 
544349cc55cSDimitry Andric   // Rewrite gather/scatter to form strided load/store if possible.
545349cc55cSDimitry Andric   for (auto *II : Gathers)
546349cc55cSDimitry Andric     Changed |= tryCreateStridedLoadStore(
547349cc55cSDimitry Andric         II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
548349cc55cSDimitry Andric   for (auto *II : Scatters)
549349cc55cSDimitry Andric     Changed |=
550349cc55cSDimitry Andric         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
551349cc55cSDimitry Andric                                   II->getArgOperand(1), II->getArgOperand(2));
552349cc55cSDimitry Andric 
553349cc55cSDimitry Andric   // Remove any dead phis.
554349cc55cSDimitry Andric   while (!MaybeDeadPHIs.empty()) {
555349cc55cSDimitry Andric     if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
556349cc55cSDimitry Andric       RecursivelyDeleteDeadPHINode(Phi);
557349cc55cSDimitry Andric   }
558349cc55cSDimitry Andric 
559349cc55cSDimitry Andric   return Changed;
560349cc55cSDimitry Andric }
561