xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (revision 349cc55c9796c4596a5b9904cd3281af295f878f)
1*349cc55cSDimitry Andric //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2*349cc55cSDimitry Andric //
3*349cc55cSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*349cc55cSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*349cc55cSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*349cc55cSDimitry Andric //
7*349cc55cSDimitry Andric //===----------------------------------------------------------------------===//
8*349cc55cSDimitry Andric //
9*349cc55cSDimitry Andric // This pass custom lowers llvm.gather and llvm.scatter instructions to
10*349cc55cSDimitry Andric // RISCV intrinsics.
11*349cc55cSDimitry Andric //
12*349cc55cSDimitry Andric //===----------------------------------------------------------------------===//
13*349cc55cSDimitry Andric 
14*349cc55cSDimitry Andric #include "RISCV.h"
15*349cc55cSDimitry Andric #include "RISCVTargetMachine.h"
16*349cc55cSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
17*349cc55cSDimitry Andric #include "llvm/Analysis/ValueTracking.h"
18*349cc55cSDimitry Andric #include "llvm/Analysis/VectorUtils.h"
19*349cc55cSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
20*349cc55cSDimitry Andric #include "llvm/IR/GetElementPtrTypeIterator.h"
21*349cc55cSDimitry Andric #include "llvm/IR/IRBuilder.h"
22*349cc55cSDimitry Andric #include "llvm/IR/IntrinsicInst.h"
23*349cc55cSDimitry Andric #include "llvm/IR/IntrinsicsRISCV.h"
24*349cc55cSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
25*349cc55cSDimitry Andric 
26*349cc55cSDimitry Andric using namespace llvm;
27*349cc55cSDimitry Andric 
28*349cc55cSDimitry Andric #define DEBUG_TYPE "riscv-gather-scatter-lowering"
29*349cc55cSDimitry Andric 
30*349cc55cSDimitry Andric namespace {
31*349cc55cSDimitry Andric 
32*349cc55cSDimitry Andric class RISCVGatherScatterLowering : public FunctionPass {
33*349cc55cSDimitry Andric   const RISCVSubtarget *ST = nullptr;
34*349cc55cSDimitry Andric   const RISCVTargetLowering *TLI = nullptr;
35*349cc55cSDimitry Andric   LoopInfo *LI = nullptr;
36*349cc55cSDimitry Andric   const DataLayout *DL = nullptr;
37*349cc55cSDimitry Andric 
38*349cc55cSDimitry Andric   SmallVector<WeakTrackingVH> MaybeDeadPHIs;
39*349cc55cSDimitry Andric 
40*349cc55cSDimitry Andric public:
41*349cc55cSDimitry Andric   static char ID; // Pass identification, replacement for typeid
42*349cc55cSDimitry Andric 
43*349cc55cSDimitry Andric   RISCVGatherScatterLowering() : FunctionPass(ID) {}
44*349cc55cSDimitry Andric 
45*349cc55cSDimitry Andric   bool runOnFunction(Function &F) override;
46*349cc55cSDimitry Andric 
47*349cc55cSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
48*349cc55cSDimitry Andric     AU.setPreservesCFG();
49*349cc55cSDimitry Andric     AU.addRequired<TargetPassConfig>();
50*349cc55cSDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
51*349cc55cSDimitry Andric   }
52*349cc55cSDimitry Andric 
53*349cc55cSDimitry Andric   StringRef getPassName() const override {
54*349cc55cSDimitry Andric     return "RISCV gather/scatter lowering";
55*349cc55cSDimitry Andric   }
56*349cc55cSDimitry Andric 
57*349cc55cSDimitry Andric private:
58*349cc55cSDimitry Andric   bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);
59*349cc55cSDimitry Andric 
60*349cc55cSDimitry Andric   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
61*349cc55cSDimitry Andric                                  Value *AlignOp);
62*349cc55cSDimitry Andric 
63*349cc55cSDimitry Andric   std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
64*349cc55cSDimitry Andric                                                      IRBuilder<> &Builder);
65*349cc55cSDimitry Andric 
66*349cc55cSDimitry Andric   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
67*349cc55cSDimitry Andric                               PHINode *&BasePtr, BinaryOperator *&Inc,
68*349cc55cSDimitry Andric                               IRBuilder<> &Builder);
69*349cc55cSDimitry Andric };
70*349cc55cSDimitry Andric 
71*349cc55cSDimitry Andric } // end anonymous namespace
72*349cc55cSDimitry Andric 
73*349cc55cSDimitry Andric char RISCVGatherScatterLowering::ID = 0;
74*349cc55cSDimitry Andric 
75*349cc55cSDimitry Andric INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
76*349cc55cSDimitry Andric                 "RISCV gather/scatter lowering pass", false, false)
77*349cc55cSDimitry Andric 
78*349cc55cSDimitry Andric FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
79*349cc55cSDimitry Andric   return new RISCVGatherScatterLowering();
80*349cc55cSDimitry Andric }
81*349cc55cSDimitry Andric 
82*349cc55cSDimitry Andric bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
83*349cc55cSDimitry Andric                                                          Value *AlignOp) {
84*349cc55cSDimitry Andric   Type *ScalarType = DataType->getScalarType();
85*349cc55cSDimitry Andric   if (!TLI->isLegalElementTypeForRVV(ScalarType))
86*349cc55cSDimitry Andric     return false;
87*349cc55cSDimitry Andric 
88*349cc55cSDimitry Andric   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
89*349cc55cSDimitry Andric   if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
90*349cc55cSDimitry Andric     return false;
91*349cc55cSDimitry Andric 
92*349cc55cSDimitry Andric   // FIXME: Let the backend type legalize by splitting/widening?
93*349cc55cSDimitry Andric   EVT DataVT = TLI->getValueType(*DL, DataType);
94*349cc55cSDimitry Andric   if (!TLI->isTypeLegal(DataVT))
95*349cc55cSDimitry Andric     return false;
96*349cc55cSDimitry Andric 
97*349cc55cSDimitry Andric   return true;
98*349cc55cSDimitry Andric }
99*349cc55cSDimitry Andric 
100*349cc55cSDimitry Andric // TODO: Should we consider the mask when looking for a stride?
101*349cc55cSDimitry Andric static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
102*349cc55cSDimitry Andric   unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
103*349cc55cSDimitry Andric 
104*349cc55cSDimitry Andric   // Check that the start value is a strided constant.
105*349cc55cSDimitry Andric   auto *StartVal =
106*349cc55cSDimitry Andric       dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
107*349cc55cSDimitry Andric   if (!StartVal)
108*349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
109*349cc55cSDimitry Andric   APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
110*349cc55cSDimitry Andric   ConstantInt *Prev = StartVal;
111*349cc55cSDimitry Andric   for (unsigned i = 1; i != NumElts; ++i) {
112*349cc55cSDimitry Andric     auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
113*349cc55cSDimitry Andric     if (!C)
114*349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
115*349cc55cSDimitry Andric 
116*349cc55cSDimitry Andric     APInt LocalStride = C->getValue() - Prev->getValue();
117*349cc55cSDimitry Andric     if (i == 1)
118*349cc55cSDimitry Andric       StrideVal = LocalStride;
119*349cc55cSDimitry Andric     else if (StrideVal != LocalStride)
120*349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
121*349cc55cSDimitry Andric 
122*349cc55cSDimitry Andric     Prev = C;
123*349cc55cSDimitry Andric   }
124*349cc55cSDimitry Andric 
125*349cc55cSDimitry Andric   Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
126*349cc55cSDimitry Andric 
127*349cc55cSDimitry Andric   return std::make_pair(StartVal, Stride);
128*349cc55cSDimitry Andric }
129*349cc55cSDimitry Andric 
130*349cc55cSDimitry Andric // Recursively, walk about the use-def chain until we find a Phi with a strided
131*349cc55cSDimitry Andric // start value. Build and update a scalar recurrence as we unwind the recursion.
132*349cc55cSDimitry Andric // We also update the Stride as we unwind. Our goal is to move all of the
133*349cc55cSDimitry Andric // arithmetic out of the loop.
134*349cc55cSDimitry Andric bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
135*349cc55cSDimitry Andric                                                         Value *&Stride,
136*349cc55cSDimitry Andric                                                         PHINode *&BasePtr,
137*349cc55cSDimitry Andric                                                         BinaryOperator *&Inc,
138*349cc55cSDimitry Andric                                                         IRBuilder<> &Builder) {
139*349cc55cSDimitry Andric   // Our base case is a Phi.
140*349cc55cSDimitry Andric   if (auto *Phi = dyn_cast<PHINode>(Index)) {
141*349cc55cSDimitry Andric     // A phi node we want to perform this function on should be from the
142*349cc55cSDimitry Andric     // loop header.
143*349cc55cSDimitry Andric     if (Phi->getParent() != L->getHeader())
144*349cc55cSDimitry Andric       return false;
145*349cc55cSDimitry Andric 
146*349cc55cSDimitry Andric     Value *Step, *Start;
147*349cc55cSDimitry Andric     if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
148*349cc55cSDimitry Andric         Inc->getOpcode() != Instruction::Add)
149*349cc55cSDimitry Andric       return false;
150*349cc55cSDimitry Andric     assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
151*349cc55cSDimitry Andric     unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
152*349cc55cSDimitry Andric     assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
153*349cc55cSDimitry Andric            "Expected one operand of phi to be Inc");
154*349cc55cSDimitry Andric 
155*349cc55cSDimitry Andric     // Only proceed if the step is loop invariant.
156*349cc55cSDimitry Andric     if (!L->isLoopInvariant(Step))
157*349cc55cSDimitry Andric       return false;
158*349cc55cSDimitry Andric 
159*349cc55cSDimitry Andric     // Step should be a splat.
160*349cc55cSDimitry Andric     Step = getSplatValue(Step);
161*349cc55cSDimitry Andric     if (!Step)
162*349cc55cSDimitry Andric       return false;
163*349cc55cSDimitry Andric 
164*349cc55cSDimitry Andric     // Start should be a strided constant.
165*349cc55cSDimitry Andric     auto *StartC = dyn_cast<Constant>(Start);
166*349cc55cSDimitry Andric     if (!StartC)
167*349cc55cSDimitry Andric       return false;
168*349cc55cSDimitry Andric 
169*349cc55cSDimitry Andric     std::tie(Start, Stride) = matchStridedConstant(StartC);
170*349cc55cSDimitry Andric     if (!Start)
171*349cc55cSDimitry Andric       return false;
172*349cc55cSDimitry Andric     assert(Stride != nullptr);
173*349cc55cSDimitry Andric 
174*349cc55cSDimitry Andric     // Build scalar phi and increment.
175*349cc55cSDimitry Andric     BasePtr =
176*349cc55cSDimitry Andric         PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
177*349cc55cSDimitry Andric     Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
178*349cc55cSDimitry Andric                                     Inc);
179*349cc55cSDimitry Andric     BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
180*349cc55cSDimitry Andric     BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
181*349cc55cSDimitry Andric 
182*349cc55cSDimitry Andric     // Note that this Phi might be eligible for removal.
183*349cc55cSDimitry Andric     MaybeDeadPHIs.push_back(Phi);
184*349cc55cSDimitry Andric     return true;
185*349cc55cSDimitry Andric   }
186*349cc55cSDimitry Andric 
187*349cc55cSDimitry Andric   // Otherwise look for binary operator.
188*349cc55cSDimitry Andric   auto *BO = dyn_cast<BinaryOperator>(Index);
189*349cc55cSDimitry Andric   if (!BO)
190*349cc55cSDimitry Andric     return false;
191*349cc55cSDimitry Andric 
192*349cc55cSDimitry Andric   if (BO->getOpcode() != Instruction::Add &&
193*349cc55cSDimitry Andric       BO->getOpcode() != Instruction::Or &&
194*349cc55cSDimitry Andric       BO->getOpcode() != Instruction::Mul &&
195*349cc55cSDimitry Andric       BO->getOpcode() != Instruction::Shl)
196*349cc55cSDimitry Andric     return false;
197*349cc55cSDimitry Andric 
198*349cc55cSDimitry Andric   // Only support shift by constant.
199*349cc55cSDimitry Andric   if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
200*349cc55cSDimitry Andric     return false;
201*349cc55cSDimitry Andric 
202*349cc55cSDimitry Andric   // We need to be able to treat Or as Add.
203*349cc55cSDimitry Andric   if (BO->getOpcode() == Instruction::Or &&
204*349cc55cSDimitry Andric       !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
205*349cc55cSDimitry Andric     return false;
206*349cc55cSDimitry Andric 
207*349cc55cSDimitry Andric   // We should have one operand in the loop and one splat.
208*349cc55cSDimitry Andric   Value *OtherOp;
209*349cc55cSDimitry Andric   if (isa<Instruction>(BO->getOperand(0)) &&
210*349cc55cSDimitry Andric       L->contains(cast<Instruction>(BO->getOperand(0)))) {
211*349cc55cSDimitry Andric     Index = cast<Instruction>(BO->getOperand(0));
212*349cc55cSDimitry Andric     OtherOp = BO->getOperand(1);
213*349cc55cSDimitry Andric   } else if (isa<Instruction>(BO->getOperand(1)) &&
214*349cc55cSDimitry Andric              L->contains(cast<Instruction>(BO->getOperand(1)))) {
215*349cc55cSDimitry Andric     Index = cast<Instruction>(BO->getOperand(1));
216*349cc55cSDimitry Andric     OtherOp = BO->getOperand(0);
217*349cc55cSDimitry Andric   } else {
218*349cc55cSDimitry Andric     return false;
219*349cc55cSDimitry Andric   }
220*349cc55cSDimitry Andric 
221*349cc55cSDimitry Andric   // Make sure other op is loop invariant.
222*349cc55cSDimitry Andric   if (!L->isLoopInvariant(OtherOp))
223*349cc55cSDimitry Andric     return false;
224*349cc55cSDimitry Andric 
225*349cc55cSDimitry Andric   // Make sure we have a splat.
226*349cc55cSDimitry Andric   Value *SplatOp = getSplatValue(OtherOp);
227*349cc55cSDimitry Andric   if (!SplatOp)
228*349cc55cSDimitry Andric     return false;
229*349cc55cSDimitry Andric 
230*349cc55cSDimitry Andric   // Recurse up the use-def chain.
231*349cc55cSDimitry Andric   if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
232*349cc55cSDimitry Andric     return false;
233*349cc55cSDimitry Andric 
234*349cc55cSDimitry Andric   // Locate the Step and Start values from the recurrence.
235*349cc55cSDimitry Andric   unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
236*349cc55cSDimitry Andric   unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
237*349cc55cSDimitry Andric   Value *Step = Inc->getOperand(StepIndex);
238*349cc55cSDimitry Andric   Value *Start = BasePtr->getOperand(StartBlock);
239*349cc55cSDimitry Andric 
240*349cc55cSDimitry Andric   // We need to adjust the start value in the preheader.
241*349cc55cSDimitry Andric   Builder.SetInsertPoint(
242*349cc55cSDimitry Andric       BasePtr->getIncomingBlock(StartBlock)->getTerminator());
243*349cc55cSDimitry Andric   Builder.SetCurrentDebugLocation(DebugLoc());
244*349cc55cSDimitry Andric 
245*349cc55cSDimitry Andric   switch (BO->getOpcode()) {
246*349cc55cSDimitry Andric   default:
247*349cc55cSDimitry Andric     llvm_unreachable("Unexpected opcode!");
248*349cc55cSDimitry Andric   case Instruction::Add:
249*349cc55cSDimitry Andric   case Instruction::Or: {
250*349cc55cSDimitry Andric     // An add only affects the start value. It's ok to do this for Or because
251*349cc55cSDimitry Andric     // we already checked that there are no common set bits.
252*349cc55cSDimitry Andric 
253*349cc55cSDimitry Andric     // If the start value is Zero, just take the SplatOp.
254*349cc55cSDimitry Andric     if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
255*349cc55cSDimitry Andric       Start = SplatOp;
256*349cc55cSDimitry Andric     else
257*349cc55cSDimitry Andric       Start = Builder.CreateAdd(Start, SplatOp, "start");
258*349cc55cSDimitry Andric     BasePtr->setIncomingValue(StartBlock, Start);
259*349cc55cSDimitry Andric     break;
260*349cc55cSDimitry Andric   }
261*349cc55cSDimitry Andric   case Instruction::Mul: {
262*349cc55cSDimitry Andric     // If the start is zero we don't need to multiply.
263*349cc55cSDimitry Andric     if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
264*349cc55cSDimitry Andric       Start = Builder.CreateMul(Start, SplatOp, "start");
265*349cc55cSDimitry Andric 
266*349cc55cSDimitry Andric     Step = Builder.CreateMul(Step, SplatOp, "step");
267*349cc55cSDimitry Andric 
268*349cc55cSDimitry Andric     // If the Stride is 1 just take the SplatOpt.
269*349cc55cSDimitry Andric     if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
270*349cc55cSDimitry Andric       Stride = SplatOp;
271*349cc55cSDimitry Andric     else
272*349cc55cSDimitry Andric       Stride = Builder.CreateMul(Stride, SplatOp, "stride");
273*349cc55cSDimitry Andric     Inc->setOperand(StepIndex, Step);
274*349cc55cSDimitry Andric     BasePtr->setIncomingValue(StartBlock, Start);
275*349cc55cSDimitry Andric     break;
276*349cc55cSDimitry Andric   }
277*349cc55cSDimitry Andric   case Instruction::Shl: {
278*349cc55cSDimitry Andric     // If the start is zero we don't need to shift.
279*349cc55cSDimitry Andric     if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
280*349cc55cSDimitry Andric       Start = Builder.CreateShl(Start, SplatOp, "start");
281*349cc55cSDimitry Andric     Step = Builder.CreateShl(Step, SplatOp, "step");
282*349cc55cSDimitry Andric     Stride = Builder.CreateShl(Stride, SplatOp, "stride");
283*349cc55cSDimitry Andric     Inc->setOperand(StepIndex, Step);
284*349cc55cSDimitry Andric     BasePtr->setIncomingValue(StartBlock, Start);
285*349cc55cSDimitry Andric     break;
286*349cc55cSDimitry Andric   }
287*349cc55cSDimitry Andric   }
288*349cc55cSDimitry Andric 
289*349cc55cSDimitry Andric   return true;
290*349cc55cSDimitry Andric }
291*349cc55cSDimitry Andric 
292*349cc55cSDimitry Andric std::pair<Value *, Value *>
293*349cc55cSDimitry Andric RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
294*349cc55cSDimitry Andric                                                    IRBuilder<> &Builder) {
295*349cc55cSDimitry Andric 
296*349cc55cSDimitry Andric   SmallVector<Value *, 2> Ops(GEP->operands());
297*349cc55cSDimitry Andric 
298*349cc55cSDimitry Andric   // Base pointer needs to be a scalar.
299*349cc55cSDimitry Andric   if (Ops[0]->getType()->isVectorTy())
300*349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
301*349cc55cSDimitry Andric 
302*349cc55cSDimitry Andric   // Make sure we're in a loop and it is in loop simplify form.
303*349cc55cSDimitry Andric   Loop *L = LI->getLoopFor(GEP->getParent());
304*349cc55cSDimitry Andric   if (!L || !L->isLoopSimplifyForm())
305*349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
306*349cc55cSDimitry Andric 
307*349cc55cSDimitry Andric   Optional<unsigned> VecOperand;
308*349cc55cSDimitry Andric   unsigned TypeScale = 0;
309*349cc55cSDimitry Andric 
310*349cc55cSDimitry Andric   // Look for a vector operand and scale.
311*349cc55cSDimitry Andric   gep_type_iterator GTI = gep_type_begin(GEP);
312*349cc55cSDimitry Andric   for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
313*349cc55cSDimitry Andric     if (!Ops[i]->getType()->isVectorTy())
314*349cc55cSDimitry Andric       continue;
315*349cc55cSDimitry Andric 
316*349cc55cSDimitry Andric     if (VecOperand)
317*349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
318*349cc55cSDimitry Andric 
319*349cc55cSDimitry Andric     VecOperand = i;
320*349cc55cSDimitry Andric 
321*349cc55cSDimitry Andric     TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
322*349cc55cSDimitry Andric     if (TS.isScalable())
323*349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
324*349cc55cSDimitry Andric 
325*349cc55cSDimitry Andric     TypeScale = TS.getFixedSize();
326*349cc55cSDimitry Andric   }
327*349cc55cSDimitry Andric 
328*349cc55cSDimitry Andric   // We need to find a vector index to simplify.
329*349cc55cSDimitry Andric   if (!VecOperand)
330*349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
331*349cc55cSDimitry Andric 
332*349cc55cSDimitry Andric   // We can't extract the stride if the arithmetic is done at a different size
333*349cc55cSDimitry Andric   // than the pointer type. Adding the stride later may not wrap correctly.
334*349cc55cSDimitry Andric   // Technically we could handle wider indices, but I don't expect that in
335*349cc55cSDimitry Andric   // practice.
336*349cc55cSDimitry Andric   Value *VecIndex = Ops[*VecOperand];
337*349cc55cSDimitry Andric   Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
338*349cc55cSDimitry Andric   if (VecIndex->getType() != VecIntPtrTy)
339*349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
340*349cc55cSDimitry Andric 
341*349cc55cSDimitry Andric   Value *Stride;
342*349cc55cSDimitry Andric   BinaryOperator *Inc;
343*349cc55cSDimitry Andric   PHINode *BasePhi;
344*349cc55cSDimitry Andric   if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
345*349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
346*349cc55cSDimitry Andric 
347*349cc55cSDimitry Andric   assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
348*349cc55cSDimitry Andric   unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
349*349cc55cSDimitry Andric   assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
350*349cc55cSDimitry Andric          "Expected one operand of phi to be Inc");
351*349cc55cSDimitry Andric 
352*349cc55cSDimitry Andric   Builder.SetInsertPoint(GEP);
353*349cc55cSDimitry Andric 
354*349cc55cSDimitry Andric   // Replace the vector index with the scalar phi and build a scalar GEP.
355*349cc55cSDimitry Andric   Ops[*VecOperand] = BasePhi;
356*349cc55cSDimitry Andric   Type *SourceTy = GEP->getSourceElementType();
357*349cc55cSDimitry Andric   Value *BasePtr =
358*349cc55cSDimitry Andric       Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());
359*349cc55cSDimitry Andric 
360*349cc55cSDimitry Andric   // Cast the GEP to an i8*.
361*349cc55cSDimitry Andric   LLVMContext &Ctx = GEP->getContext();
362*349cc55cSDimitry Andric   Type *I8PtrTy =
363*349cc55cSDimitry Andric       Type::getInt8PtrTy(Ctx, GEP->getType()->getPointerAddressSpace());
364*349cc55cSDimitry Andric   if (BasePtr->getType() != I8PtrTy)
365*349cc55cSDimitry Andric     BasePtr = Builder.CreatePointerCast(BasePtr, I8PtrTy);
366*349cc55cSDimitry Andric 
367*349cc55cSDimitry Andric   // Final adjustments to stride should go in the start block.
368*349cc55cSDimitry Andric   Builder.SetInsertPoint(
369*349cc55cSDimitry Andric       BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());
370*349cc55cSDimitry Andric 
371*349cc55cSDimitry Andric   // Convert stride to pointer size if needed.
372*349cc55cSDimitry Andric   Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
373*349cc55cSDimitry Andric   assert(Stride->getType() == IntPtrTy && "Unexpected type");
374*349cc55cSDimitry Andric 
375*349cc55cSDimitry Andric   // Scale the stride by the size of the indexed type.
376*349cc55cSDimitry Andric   if (TypeScale != 1)
377*349cc55cSDimitry Andric     Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
378*349cc55cSDimitry Andric 
379*349cc55cSDimitry Andric   return std::make_pair(BasePtr, Stride);
380*349cc55cSDimitry Andric }
381*349cc55cSDimitry Andric 
382*349cc55cSDimitry Andric bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
383*349cc55cSDimitry Andric                                                            Type *DataType,
384*349cc55cSDimitry Andric                                                            Value *Ptr,
385*349cc55cSDimitry Andric                                                            Value *AlignOp) {
386*349cc55cSDimitry Andric   // Make sure the operation will be supported by the backend.
387*349cc55cSDimitry Andric   if (!isLegalTypeAndAlignment(DataType, AlignOp))
388*349cc55cSDimitry Andric     return false;
389*349cc55cSDimitry Andric 
390*349cc55cSDimitry Andric   // Pointer should be a GEP.
391*349cc55cSDimitry Andric   auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
392*349cc55cSDimitry Andric   if (!GEP)
393*349cc55cSDimitry Andric     return false;
394*349cc55cSDimitry Andric 
395*349cc55cSDimitry Andric   IRBuilder<> Builder(GEP);
396*349cc55cSDimitry Andric 
397*349cc55cSDimitry Andric   Value *BasePtr, *Stride;
398*349cc55cSDimitry Andric   std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
399*349cc55cSDimitry Andric   if (!BasePtr)
400*349cc55cSDimitry Andric     return false;
401*349cc55cSDimitry Andric   assert(Stride != nullptr);
402*349cc55cSDimitry Andric 
403*349cc55cSDimitry Andric   Builder.SetInsertPoint(II);
404*349cc55cSDimitry Andric 
405*349cc55cSDimitry Andric   CallInst *Call;
406*349cc55cSDimitry Andric   if (II->getIntrinsicID() == Intrinsic::masked_gather)
407*349cc55cSDimitry Andric     Call = Builder.CreateIntrinsic(
408*349cc55cSDimitry Andric         Intrinsic::riscv_masked_strided_load,
409*349cc55cSDimitry Andric         {DataType, BasePtr->getType(), Stride->getType()},
410*349cc55cSDimitry Andric         {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
411*349cc55cSDimitry Andric   else
412*349cc55cSDimitry Andric     Call = Builder.CreateIntrinsic(
413*349cc55cSDimitry Andric         Intrinsic::riscv_masked_strided_store,
414*349cc55cSDimitry Andric         {DataType, BasePtr->getType(), Stride->getType()},
415*349cc55cSDimitry Andric         {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
416*349cc55cSDimitry Andric 
417*349cc55cSDimitry Andric   Call->takeName(II);
418*349cc55cSDimitry Andric   II->replaceAllUsesWith(Call);
419*349cc55cSDimitry Andric   II->eraseFromParent();
420*349cc55cSDimitry Andric 
421*349cc55cSDimitry Andric   if (GEP->use_empty())
422*349cc55cSDimitry Andric     RecursivelyDeleteTriviallyDeadInstructions(GEP);
423*349cc55cSDimitry Andric 
424*349cc55cSDimitry Andric   return true;
425*349cc55cSDimitry Andric }
426*349cc55cSDimitry Andric 
427*349cc55cSDimitry Andric bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
428*349cc55cSDimitry Andric   if (skipFunction(F))
429*349cc55cSDimitry Andric     return false;
430*349cc55cSDimitry Andric 
431*349cc55cSDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
432*349cc55cSDimitry Andric   auto &TM = TPC.getTM<RISCVTargetMachine>();
433*349cc55cSDimitry Andric   ST = &TM.getSubtarget<RISCVSubtarget>(F);
434*349cc55cSDimitry Andric   if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
435*349cc55cSDimitry Andric     return false;
436*349cc55cSDimitry Andric 
437*349cc55cSDimitry Andric   TLI = ST->getTargetLowering();
438*349cc55cSDimitry Andric   DL = &F.getParent()->getDataLayout();
439*349cc55cSDimitry Andric   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
440*349cc55cSDimitry Andric 
441*349cc55cSDimitry Andric   SmallVector<IntrinsicInst *, 4> Gathers;
442*349cc55cSDimitry Andric   SmallVector<IntrinsicInst *, 4> Scatters;
443*349cc55cSDimitry Andric 
444*349cc55cSDimitry Andric   bool Changed = false;
445*349cc55cSDimitry Andric 
446*349cc55cSDimitry Andric   for (BasicBlock &BB : F) {
447*349cc55cSDimitry Andric     for (Instruction &I : BB) {
448*349cc55cSDimitry Andric       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
449*349cc55cSDimitry Andric       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
450*349cc55cSDimitry Andric           isa<FixedVectorType>(II->getType())) {
451*349cc55cSDimitry Andric         Gathers.push_back(II);
452*349cc55cSDimitry Andric       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
453*349cc55cSDimitry Andric                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
454*349cc55cSDimitry Andric         Scatters.push_back(II);
455*349cc55cSDimitry Andric       }
456*349cc55cSDimitry Andric     }
457*349cc55cSDimitry Andric   }
458*349cc55cSDimitry Andric 
459*349cc55cSDimitry Andric   // Rewrite gather/scatter to form strided load/store if possible.
460*349cc55cSDimitry Andric   for (auto *II : Gathers)
461*349cc55cSDimitry Andric     Changed |= tryCreateStridedLoadStore(
462*349cc55cSDimitry Andric         II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
463*349cc55cSDimitry Andric   for (auto *II : Scatters)
464*349cc55cSDimitry Andric     Changed |=
465*349cc55cSDimitry Andric         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
466*349cc55cSDimitry Andric                                   II->getArgOperand(1), II->getArgOperand(2));
467*349cc55cSDimitry Andric 
468*349cc55cSDimitry Andric   // Remove any dead phis.
469*349cc55cSDimitry Andric   while (!MaybeDeadPHIs.empty()) {
470*349cc55cSDimitry Andric     if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
471*349cc55cSDimitry Andric       RecursivelyDeleteDeadPHINode(Phi);
472*349cc55cSDimitry Andric   }
473*349cc55cSDimitry Andric 
474*349cc55cSDimitry Andric   return Changed;
475*349cc55cSDimitry Andric }
476