xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (revision 04eeddc0aa8e0a417a16eaf9d7d095207f4a8623)
1349cc55cSDimitry Andric //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2349cc55cSDimitry Andric //
3349cc55cSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4349cc55cSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5349cc55cSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6349cc55cSDimitry Andric //
7349cc55cSDimitry Andric //===----------------------------------------------------------------------===//
8349cc55cSDimitry Andric //
9349cc55cSDimitry Andric // This pass custom lowers llvm.gather and llvm.scatter instructions to
10349cc55cSDimitry Andric // RISCV intrinsics.
11349cc55cSDimitry Andric //
12349cc55cSDimitry Andric //===----------------------------------------------------------------------===//
13349cc55cSDimitry Andric 
14349cc55cSDimitry Andric #include "RISCV.h"
15349cc55cSDimitry Andric #include "RISCVTargetMachine.h"
16349cc55cSDimitry Andric #include "llvm/Analysis/LoopInfo.h"
17349cc55cSDimitry Andric #include "llvm/Analysis/ValueTracking.h"
18349cc55cSDimitry Andric #include "llvm/Analysis/VectorUtils.h"
19349cc55cSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
20349cc55cSDimitry Andric #include "llvm/IR/GetElementPtrTypeIterator.h"
21349cc55cSDimitry Andric #include "llvm/IR/IRBuilder.h"
22349cc55cSDimitry Andric #include "llvm/IR/IntrinsicInst.h"
23349cc55cSDimitry Andric #include "llvm/IR/IntrinsicsRISCV.h"
24349cc55cSDimitry Andric #include "llvm/Transforms/Utils/Local.h"
25349cc55cSDimitry Andric 
26349cc55cSDimitry Andric using namespace llvm;
27349cc55cSDimitry Andric 
28349cc55cSDimitry Andric #define DEBUG_TYPE "riscv-gather-scatter-lowering"
29349cc55cSDimitry Andric 
30349cc55cSDimitry Andric namespace {
31349cc55cSDimitry Andric 
32349cc55cSDimitry Andric class RISCVGatherScatterLowering : public FunctionPass {
33349cc55cSDimitry Andric   const RISCVSubtarget *ST = nullptr;
34349cc55cSDimitry Andric   const RISCVTargetLowering *TLI = nullptr;
35349cc55cSDimitry Andric   LoopInfo *LI = nullptr;
36349cc55cSDimitry Andric   const DataLayout *DL = nullptr;
37349cc55cSDimitry Andric 
38349cc55cSDimitry Andric   SmallVector<WeakTrackingVH> MaybeDeadPHIs;
39349cc55cSDimitry Andric 
40349cc55cSDimitry Andric public:
41349cc55cSDimitry Andric   static char ID; // Pass identification, replacement for typeid
42349cc55cSDimitry Andric 
43349cc55cSDimitry Andric   RISCVGatherScatterLowering() : FunctionPass(ID) {}
44349cc55cSDimitry Andric 
45349cc55cSDimitry Andric   bool runOnFunction(Function &F) override;
46349cc55cSDimitry Andric 
47349cc55cSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
48349cc55cSDimitry Andric     AU.setPreservesCFG();
49349cc55cSDimitry Andric     AU.addRequired<TargetPassConfig>();
50349cc55cSDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
51349cc55cSDimitry Andric   }
52349cc55cSDimitry Andric 
53349cc55cSDimitry Andric   StringRef getPassName() const override {
54349cc55cSDimitry Andric     return "RISCV gather/scatter lowering";
55349cc55cSDimitry Andric   }
56349cc55cSDimitry Andric 
57349cc55cSDimitry Andric private:
58349cc55cSDimitry Andric   bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);
59349cc55cSDimitry Andric 
60349cc55cSDimitry Andric   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
61349cc55cSDimitry Andric                                  Value *AlignOp);
62349cc55cSDimitry Andric 
63349cc55cSDimitry Andric   std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
64349cc55cSDimitry Andric                                                      IRBuilder<> &Builder);
65349cc55cSDimitry Andric 
66349cc55cSDimitry Andric   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
67349cc55cSDimitry Andric                               PHINode *&BasePtr, BinaryOperator *&Inc,
68349cc55cSDimitry Andric                               IRBuilder<> &Builder);
69349cc55cSDimitry Andric };
70349cc55cSDimitry Andric 
71349cc55cSDimitry Andric } // end anonymous namespace
72349cc55cSDimitry Andric 
73349cc55cSDimitry Andric char RISCVGatherScatterLowering::ID = 0;
74349cc55cSDimitry Andric 
75349cc55cSDimitry Andric INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
76349cc55cSDimitry Andric                 "RISCV gather/scatter lowering pass", false, false)
77349cc55cSDimitry Andric 
78349cc55cSDimitry Andric FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
79349cc55cSDimitry Andric   return new RISCVGatherScatterLowering();
80349cc55cSDimitry Andric }
81349cc55cSDimitry Andric 
82349cc55cSDimitry Andric bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
83349cc55cSDimitry Andric                                                          Value *AlignOp) {
84349cc55cSDimitry Andric   Type *ScalarType = DataType->getScalarType();
85349cc55cSDimitry Andric   if (!TLI->isLegalElementTypeForRVV(ScalarType))
86349cc55cSDimitry Andric     return false;
87349cc55cSDimitry Andric 
88349cc55cSDimitry Andric   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
89349cc55cSDimitry Andric   if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
90349cc55cSDimitry Andric     return false;
91349cc55cSDimitry Andric 
92349cc55cSDimitry Andric   // FIXME: Let the backend type legalize by splitting/widening?
93349cc55cSDimitry Andric   EVT DataVT = TLI->getValueType(*DL, DataType);
94349cc55cSDimitry Andric   if (!TLI->isTypeLegal(DataVT))
95349cc55cSDimitry Andric     return false;
96349cc55cSDimitry Andric 
97349cc55cSDimitry Andric   return true;
98349cc55cSDimitry Andric }
99349cc55cSDimitry Andric 
100349cc55cSDimitry Andric // TODO: Should we consider the mask when looking for a stride?
101349cc55cSDimitry Andric static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
102349cc55cSDimitry Andric   unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
103349cc55cSDimitry Andric 
104349cc55cSDimitry Andric   // Check that the start value is a strided constant.
105349cc55cSDimitry Andric   auto *StartVal =
106349cc55cSDimitry Andric       dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
107349cc55cSDimitry Andric   if (!StartVal)
108349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
109349cc55cSDimitry Andric   APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
110349cc55cSDimitry Andric   ConstantInt *Prev = StartVal;
111349cc55cSDimitry Andric   for (unsigned i = 1; i != NumElts; ++i) {
112349cc55cSDimitry Andric     auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
113349cc55cSDimitry Andric     if (!C)
114349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
115349cc55cSDimitry Andric 
116349cc55cSDimitry Andric     APInt LocalStride = C->getValue() - Prev->getValue();
117349cc55cSDimitry Andric     if (i == 1)
118349cc55cSDimitry Andric       StrideVal = LocalStride;
119349cc55cSDimitry Andric     else if (StrideVal != LocalStride)
120349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
121349cc55cSDimitry Andric 
122349cc55cSDimitry Andric     Prev = C;
123349cc55cSDimitry Andric   }
124349cc55cSDimitry Andric 
125349cc55cSDimitry Andric   Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
126349cc55cSDimitry Andric 
127349cc55cSDimitry Andric   return std::make_pair(StartVal, Stride);
128349cc55cSDimitry Andric }
129349cc55cSDimitry Andric 
130*04eeddc0SDimitry Andric static std::pair<Value *, Value *> matchStridedStart(Value *Start,
131*04eeddc0SDimitry Andric                                                      IRBuilder<> &Builder) {
132*04eeddc0SDimitry Andric   // Base case, start is a strided constant.
133*04eeddc0SDimitry Andric   auto *StartC = dyn_cast<Constant>(Start);
134*04eeddc0SDimitry Andric   if (StartC)
135*04eeddc0SDimitry Andric     return matchStridedConstant(StartC);
136*04eeddc0SDimitry Andric 
137*04eeddc0SDimitry Andric   // Not a constant, maybe it's a strided constant with a splat added to it.
138*04eeddc0SDimitry Andric   auto *BO = dyn_cast<BinaryOperator>(Start);
139*04eeddc0SDimitry Andric   if (!BO || BO->getOpcode() != Instruction::Add)
140*04eeddc0SDimitry Andric     return std::make_pair(nullptr, nullptr);
141*04eeddc0SDimitry Andric 
142*04eeddc0SDimitry Andric   // Look for an operand that is splatted.
143*04eeddc0SDimitry Andric   unsigned OtherIndex = 1;
144*04eeddc0SDimitry Andric   Value *Splat = getSplatValue(BO->getOperand(0));
145*04eeddc0SDimitry Andric   if (!Splat) {
146*04eeddc0SDimitry Andric     Splat = getSplatValue(BO->getOperand(1));
147*04eeddc0SDimitry Andric     OtherIndex = 0;
148*04eeddc0SDimitry Andric   }
149*04eeddc0SDimitry Andric   if (!Splat)
150*04eeddc0SDimitry Andric     return std::make_pair(nullptr, nullptr);
151*04eeddc0SDimitry Andric 
152*04eeddc0SDimitry Andric   Value *Stride;
153*04eeddc0SDimitry Andric   std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
154*04eeddc0SDimitry Andric                                               Builder);
155*04eeddc0SDimitry Andric   if (!Start)
156*04eeddc0SDimitry Andric     return std::make_pair(nullptr, nullptr);
157*04eeddc0SDimitry Andric 
158*04eeddc0SDimitry Andric   // Add the splat value to the start.
159*04eeddc0SDimitry Andric   Builder.SetInsertPoint(BO);
160*04eeddc0SDimitry Andric   Builder.SetCurrentDebugLocation(DebugLoc());
161*04eeddc0SDimitry Andric   Start = Builder.CreateAdd(Start, Splat);
162*04eeddc0SDimitry Andric   return std::make_pair(Start, Stride);
163*04eeddc0SDimitry Andric }
164*04eeddc0SDimitry Andric 
165349cc55cSDimitry Andric // Recursively, walk about the use-def chain until we find a Phi with a strided
166349cc55cSDimitry Andric // start value. Build and update a scalar recurrence as we unwind the recursion.
167349cc55cSDimitry Andric // We also update the Stride as we unwind. Our goal is to move all of the
168349cc55cSDimitry Andric // arithmetic out of the loop.
169349cc55cSDimitry Andric bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
170349cc55cSDimitry Andric                                                         Value *&Stride,
171349cc55cSDimitry Andric                                                         PHINode *&BasePtr,
172349cc55cSDimitry Andric                                                         BinaryOperator *&Inc,
173349cc55cSDimitry Andric                                                         IRBuilder<> &Builder) {
174349cc55cSDimitry Andric   // Our base case is a Phi.
175349cc55cSDimitry Andric   if (auto *Phi = dyn_cast<PHINode>(Index)) {
176349cc55cSDimitry Andric     // A phi node we want to perform this function on should be from the
177349cc55cSDimitry Andric     // loop header.
178349cc55cSDimitry Andric     if (Phi->getParent() != L->getHeader())
179349cc55cSDimitry Andric       return false;
180349cc55cSDimitry Andric 
181349cc55cSDimitry Andric     Value *Step, *Start;
182349cc55cSDimitry Andric     if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
183349cc55cSDimitry Andric         Inc->getOpcode() != Instruction::Add)
184349cc55cSDimitry Andric       return false;
185349cc55cSDimitry Andric     assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
186349cc55cSDimitry Andric     unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
187349cc55cSDimitry Andric     assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
188349cc55cSDimitry Andric            "Expected one operand of phi to be Inc");
189349cc55cSDimitry Andric 
190349cc55cSDimitry Andric     // Only proceed if the step is loop invariant.
191349cc55cSDimitry Andric     if (!L->isLoopInvariant(Step))
192349cc55cSDimitry Andric       return false;
193349cc55cSDimitry Andric 
194349cc55cSDimitry Andric     // Step should be a splat.
195349cc55cSDimitry Andric     Step = getSplatValue(Step);
196349cc55cSDimitry Andric     if (!Step)
197349cc55cSDimitry Andric       return false;
198349cc55cSDimitry Andric 
199*04eeddc0SDimitry Andric     std::tie(Start, Stride) = matchStridedStart(Start, Builder);
200349cc55cSDimitry Andric     if (!Start)
201349cc55cSDimitry Andric       return false;
202349cc55cSDimitry Andric     assert(Stride != nullptr);
203349cc55cSDimitry Andric 
204349cc55cSDimitry Andric     // Build scalar phi and increment.
205349cc55cSDimitry Andric     BasePtr =
206349cc55cSDimitry Andric         PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
207349cc55cSDimitry Andric     Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
208349cc55cSDimitry Andric                                     Inc);
209349cc55cSDimitry Andric     BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
210349cc55cSDimitry Andric     BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
211349cc55cSDimitry Andric 
212349cc55cSDimitry Andric     // Note that this Phi might be eligible for removal.
213349cc55cSDimitry Andric     MaybeDeadPHIs.push_back(Phi);
214349cc55cSDimitry Andric     return true;
215349cc55cSDimitry Andric   }
216349cc55cSDimitry Andric 
217349cc55cSDimitry Andric   // Otherwise look for binary operator.
218349cc55cSDimitry Andric   auto *BO = dyn_cast<BinaryOperator>(Index);
219349cc55cSDimitry Andric   if (!BO)
220349cc55cSDimitry Andric     return false;
221349cc55cSDimitry Andric 
222349cc55cSDimitry Andric   if (BO->getOpcode() != Instruction::Add &&
223349cc55cSDimitry Andric       BO->getOpcode() != Instruction::Or &&
224349cc55cSDimitry Andric       BO->getOpcode() != Instruction::Mul &&
225349cc55cSDimitry Andric       BO->getOpcode() != Instruction::Shl)
226349cc55cSDimitry Andric     return false;
227349cc55cSDimitry Andric 
228349cc55cSDimitry Andric   // Only support shift by constant.
229349cc55cSDimitry Andric   if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
230349cc55cSDimitry Andric     return false;
231349cc55cSDimitry Andric 
232349cc55cSDimitry Andric   // We need to be able to treat Or as Add.
233349cc55cSDimitry Andric   if (BO->getOpcode() == Instruction::Or &&
234349cc55cSDimitry Andric       !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
235349cc55cSDimitry Andric     return false;
236349cc55cSDimitry Andric 
237349cc55cSDimitry Andric   // We should have one operand in the loop and one splat.
238349cc55cSDimitry Andric   Value *OtherOp;
239349cc55cSDimitry Andric   if (isa<Instruction>(BO->getOperand(0)) &&
240349cc55cSDimitry Andric       L->contains(cast<Instruction>(BO->getOperand(0)))) {
241349cc55cSDimitry Andric     Index = cast<Instruction>(BO->getOperand(0));
242349cc55cSDimitry Andric     OtherOp = BO->getOperand(1);
243349cc55cSDimitry Andric   } else if (isa<Instruction>(BO->getOperand(1)) &&
244349cc55cSDimitry Andric              L->contains(cast<Instruction>(BO->getOperand(1)))) {
245349cc55cSDimitry Andric     Index = cast<Instruction>(BO->getOperand(1));
246349cc55cSDimitry Andric     OtherOp = BO->getOperand(0);
247349cc55cSDimitry Andric   } else {
248349cc55cSDimitry Andric     return false;
249349cc55cSDimitry Andric   }
250349cc55cSDimitry Andric 
251349cc55cSDimitry Andric   // Make sure other op is loop invariant.
252349cc55cSDimitry Andric   if (!L->isLoopInvariant(OtherOp))
253349cc55cSDimitry Andric     return false;
254349cc55cSDimitry Andric 
255349cc55cSDimitry Andric   // Make sure we have a splat.
256349cc55cSDimitry Andric   Value *SplatOp = getSplatValue(OtherOp);
257349cc55cSDimitry Andric   if (!SplatOp)
258349cc55cSDimitry Andric     return false;
259349cc55cSDimitry Andric 
260349cc55cSDimitry Andric   // Recurse up the use-def chain.
261349cc55cSDimitry Andric   if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
262349cc55cSDimitry Andric     return false;
263349cc55cSDimitry Andric 
264349cc55cSDimitry Andric   // Locate the Step and Start values from the recurrence.
265349cc55cSDimitry Andric   unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
266349cc55cSDimitry Andric   unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
267349cc55cSDimitry Andric   Value *Step = Inc->getOperand(StepIndex);
268349cc55cSDimitry Andric   Value *Start = BasePtr->getOperand(StartBlock);
269349cc55cSDimitry Andric 
270349cc55cSDimitry Andric   // We need to adjust the start value in the preheader.
271349cc55cSDimitry Andric   Builder.SetInsertPoint(
272349cc55cSDimitry Andric       BasePtr->getIncomingBlock(StartBlock)->getTerminator());
273349cc55cSDimitry Andric   Builder.SetCurrentDebugLocation(DebugLoc());
274349cc55cSDimitry Andric 
275349cc55cSDimitry Andric   switch (BO->getOpcode()) {
276349cc55cSDimitry Andric   default:
277349cc55cSDimitry Andric     llvm_unreachable("Unexpected opcode!");
278349cc55cSDimitry Andric   case Instruction::Add:
279349cc55cSDimitry Andric   case Instruction::Or: {
280349cc55cSDimitry Andric     // An add only affects the start value. It's ok to do this for Or because
281349cc55cSDimitry Andric     // we already checked that there are no common set bits.
282349cc55cSDimitry Andric 
283349cc55cSDimitry Andric     // If the start value is Zero, just take the SplatOp.
284349cc55cSDimitry Andric     if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
285349cc55cSDimitry Andric       Start = SplatOp;
286349cc55cSDimitry Andric     else
287349cc55cSDimitry Andric       Start = Builder.CreateAdd(Start, SplatOp, "start");
288349cc55cSDimitry Andric     BasePtr->setIncomingValue(StartBlock, Start);
289349cc55cSDimitry Andric     break;
290349cc55cSDimitry Andric   }
291349cc55cSDimitry Andric   case Instruction::Mul: {
292349cc55cSDimitry Andric     // If the start is zero we don't need to multiply.
293349cc55cSDimitry Andric     if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
294349cc55cSDimitry Andric       Start = Builder.CreateMul(Start, SplatOp, "start");
295349cc55cSDimitry Andric 
296349cc55cSDimitry Andric     Step = Builder.CreateMul(Step, SplatOp, "step");
297349cc55cSDimitry Andric 
298349cc55cSDimitry Andric     // If the Stride is 1 just take the SplatOpt.
299349cc55cSDimitry Andric     if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
300349cc55cSDimitry Andric       Stride = SplatOp;
301349cc55cSDimitry Andric     else
302349cc55cSDimitry Andric       Stride = Builder.CreateMul(Stride, SplatOp, "stride");
303349cc55cSDimitry Andric     Inc->setOperand(StepIndex, Step);
304349cc55cSDimitry Andric     BasePtr->setIncomingValue(StartBlock, Start);
305349cc55cSDimitry Andric     break;
306349cc55cSDimitry Andric   }
307349cc55cSDimitry Andric   case Instruction::Shl: {
308349cc55cSDimitry Andric     // If the start is zero we don't need to shift.
309349cc55cSDimitry Andric     if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
310349cc55cSDimitry Andric       Start = Builder.CreateShl(Start, SplatOp, "start");
311349cc55cSDimitry Andric     Step = Builder.CreateShl(Step, SplatOp, "step");
312349cc55cSDimitry Andric     Stride = Builder.CreateShl(Stride, SplatOp, "stride");
313349cc55cSDimitry Andric     Inc->setOperand(StepIndex, Step);
314349cc55cSDimitry Andric     BasePtr->setIncomingValue(StartBlock, Start);
315349cc55cSDimitry Andric     break;
316349cc55cSDimitry Andric   }
317349cc55cSDimitry Andric   }
318349cc55cSDimitry Andric 
319349cc55cSDimitry Andric   return true;
320349cc55cSDimitry Andric }
321349cc55cSDimitry Andric 
322349cc55cSDimitry Andric std::pair<Value *, Value *>
323349cc55cSDimitry Andric RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
324349cc55cSDimitry Andric                                                    IRBuilder<> &Builder) {
325349cc55cSDimitry Andric 
326349cc55cSDimitry Andric   SmallVector<Value *, 2> Ops(GEP->operands());
327349cc55cSDimitry Andric 
328349cc55cSDimitry Andric   // Base pointer needs to be a scalar.
329349cc55cSDimitry Andric   if (Ops[0]->getType()->isVectorTy())
330349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
331349cc55cSDimitry Andric 
332349cc55cSDimitry Andric   // Make sure we're in a loop and it is in loop simplify form.
333349cc55cSDimitry Andric   Loop *L = LI->getLoopFor(GEP->getParent());
334349cc55cSDimitry Andric   if (!L || !L->isLoopSimplifyForm())
335349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
336349cc55cSDimitry Andric 
337349cc55cSDimitry Andric   Optional<unsigned> VecOperand;
338349cc55cSDimitry Andric   unsigned TypeScale = 0;
339349cc55cSDimitry Andric 
340349cc55cSDimitry Andric   // Look for a vector operand and scale.
341349cc55cSDimitry Andric   gep_type_iterator GTI = gep_type_begin(GEP);
342349cc55cSDimitry Andric   for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
343349cc55cSDimitry Andric     if (!Ops[i]->getType()->isVectorTy())
344349cc55cSDimitry Andric       continue;
345349cc55cSDimitry Andric 
346349cc55cSDimitry Andric     if (VecOperand)
347349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
348349cc55cSDimitry Andric 
349349cc55cSDimitry Andric     VecOperand = i;
350349cc55cSDimitry Andric 
351349cc55cSDimitry Andric     TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
352349cc55cSDimitry Andric     if (TS.isScalable())
353349cc55cSDimitry Andric       return std::make_pair(nullptr, nullptr);
354349cc55cSDimitry Andric 
355349cc55cSDimitry Andric     TypeScale = TS.getFixedSize();
356349cc55cSDimitry Andric   }
357349cc55cSDimitry Andric 
358349cc55cSDimitry Andric   // We need to find a vector index to simplify.
359349cc55cSDimitry Andric   if (!VecOperand)
360349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
361349cc55cSDimitry Andric 
362349cc55cSDimitry Andric   // We can't extract the stride if the arithmetic is done at a different size
363349cc55cSDimitry Andric   // than the pointer type. Adding the stride later may not wrap correctly.
364349cc55cSDimitry Andric   // Technically we could handle wider indices, but I don't expect that in
365349cc55cSDimitry Andric   // practice.
366349cc55cSDimitry Andric   Value *VecIndex = Ops[*VecOperand];
367349cc55cSDimitry Andric   Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
368349cc55cSDimitry Andric   if (VecIndex->getType() != VecIntPtrTy)
369349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
370349cc55cSDimitry Andric 
371349cc55cSDimitry Andric   Value *Stride;
372349cc55cSDimitry Andric   BinaryOperator *Inc;
373349cc55cSDimitry Andric   PHINode *BasePhi;
374349cc55cSDimitry Andric   if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
375349cc55cSDimitry Andric     return std::make_pair(nullptr, nullptr);
376349cc55cSDimitry Andric 
377349cc55cSDimitry Andric   assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
378349cc55cSDimitry Andric   unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
379349cc55cSDimitry Andric   assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
380349cc55cSDimitry Andric          "Expected one operand of phi to be Inc");
381349cc55cSDimitry Andric 
382349cc55cSDimitry Andric   Builder.SetInsertPoint(GEP);
383349cc55cSDimitry Andric 
384349cc55cSDimitry Andric   // Replace the vector index with the scalar phi and build a scalar GEP.
385349cc55cSDimitry Andric   Ops[*VecOperand] = BasePhi;
386349cc55cSDimitry Andric   Type *SourceTy = GEP->getSourceElementType();
387349cc55cSDimitry Andric   Value *BasePtr =
388349cc55cSDimitry Andric       Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());
389349cc55cSDimitry Andric 
390349cc55cSDimitry Andric   // Cast the GEP to an i8*.
391349cc55cSDimitry Andric   LLVMContext &Ctx = GEP->getContext();
392349cc55cSDimitry Andric   Type *I8PtrTy =
393349cc55cSDimitry Andric       Type::getInt8PtrTy(Ctx, GEP->getType()->getPointerAddressSpace());
394349cc55cSDimitry Andric   if (BasePtr->getType() != I8PtrTy)
395349cc55cSDimitry Andric     BasePtr = Builder.CreatePointerCast(BasePtr, I8PtrTy);
396349cc55cSDimitry Andric 
397349cc55cSDimitry Andric   // Final adjustments to stride should go in the start block.
398349cc55cSDimitry Andric   Builder.SetInsertPoint(
399349cc55cSDimitry Andric       BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());
400349cc55cSDimitry Andric 
401349cc55cSDimitry Andric   // Convert stride to pointer size if needed.
402349cc55cSDimitry Andric   Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
403349cc55cSDimitry Andric   assert(Stride->getType() == IntPtrTy && "Unexpected type");
404349cc55cSDimitry Andric 
405349cc55cSDimitry Andric   // Scale the stride by the size of the indexed type.
406349cc55cSDimitry Andric   if (TypeScale != 1)
407349cc55cSDimitry Andric     Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
408349cc55cSDimitry Andric 
409349cc55cSDimitry Andric   return std::make_pair(BasePtr, Stride);
410349cc55cSDimitry Andric }
411349cc55cSDimitry Andric 
412349cc55cSDimitry Andric bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
413349cc55cSDimitry Andric                                                            Type *DataType,
414349cc55cSDimitry Andric                                                            Value *Ptr,
415349cc55cSDimitry Andric                                                            Value *AlignOp) {
416349cc55cSDimitry Andric   // Make sure the operation will be supported by the backend.
417349cc55cSDimitry Andric   if (!isLegalTypeAndAlignment(DataType, AlignOp))
418349cc55cSDimitry Andric     return false;
419349cc55cSDimitry Andric 
420349cc55cSDimitry Andric   // Pointer should be a GEP.
421349cc55cSDimitry Andric   auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
422349cc55cSDimitry Andric   if (!GEP)
423349cc55cSDimitry Andric     return false;
424349cc55cSDimitry Andric 
425349cc55cSDimitry Andric   IRBuilder<> Builder(GEP);
426349cc55cSDimitry Andric 
427349cc55cSDimitry Andric   Value *BasePtr, *Stride;
428349cc55cSDimitry Andric   std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
429349cc55cSDimitry Andric   if (!BasePtr)
430349cc55cSDimitry Andric     return false;
431349cc55cSDimitry Andric   assert(Stride != nullptr);
432349cc55cSDimitry Andric 
433349cc55cSDimitry Andric   Builder.SetInsertPoint(II);
434349cc55cSDimitry Andric 
435349cc55cSDimitry Andric   CallInst *Call;
436349cc55cSDimitry Andric   if (II->getIntrinsicID() == Intrinsic::masked_gather)
437349cc55cSDimitry Andric     Call = Builder.CreateIntrinsic(
438349cc55cSDimitry Andric         Intrinsic::riscv_masked_strided_load,
439349cc55cSDimitry Andric         {DataType, BasePtr->getType(), Stride->getType()},
440349cc55cSDimitry Andric         {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
441349cc55cSDimitry Andric   else
442349cc55cSDimitry Andric     Call = Builder.CreateIntrinsic(
443349cc55cSDimitry Andric         Intrinsic::riscv_masked_strided_store,
444349cc55cSDimitry Andric         {DataType, BasePtr->getType(), Stride->getType()},
445349cc55cSDimitry Andric         {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
446349cc55cSDimitry Andric 
447349cc55cSDimitry Andric   Call->takeName(II);
448349cc55cSDimitry Andric   II->replaceAllUsesWith(Call);
449349cc55cSDimitry Andric   II->eraseFromParent();
450349cc55cSDimitry Andric 
451349cc55cSDimitry Andric   if (GEP->use_empty())
452349cc55cSDimitry Andric     RecursivelyDeleteTriviallyDeadInstructions(GEP);
453349cc55cSDimitry Andric 
454349cc55cSDimitry Andric   return true;
455349cc55cSDimitry Andric }
456349cc55cSDimitry Andric 
457349cc55cSDimitry Andric bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
458349cc55cSDimitry Andric   if (skipFunction(F))
459349cc55cSDimitry Andric     return false;
460349cc55cSDimitry Andric 
461349cc55cSDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
462349cc55cSDimitry Andric   auto &TM = TPC.getTM<RISCVTargetMachine>();
463349cc55cSDimitry Andric   ST = &TM.getSubtarget<RISCVSubtarget>(F);
464349cc55cSDimitry Andric   if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
465349cc55cSDimitry Andric     return false;
466349cc55cSDimitry Andric 
467349cc55cSDimitry Andric   TLI = ST->getTargetLowering();
468349cc55cSDimitry Andric   DL = &F.getParent()->getDataLayout();
469349cc55cSDimitry Andric   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
470349cc55cSDimitry Andric 
471349cc55cSDimitry Andric   SmallVector<IntrinsicInst *, 4> Gathers;
472349cc55cSDimitry Andric   SmallVector<IntrinsicInst *, 4> Scatters;
473349cc55cSDimitry Andric 
474349cc55cSDimitry Andric   bool Changed = false;
475349cc55cSDimitry Andric 
476349cc55cSDimitry Andric   for (BasicBlock &BB : F) {
477349cc55cSDimitry Andric     for (Instruction &I : BB) {
478349cc55cSDimitry Andric       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
479349cc55cSDimitry Andric       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
480349cc55cSDimitry Andric           isa<FixedVectorType>(II->getType())) {
481349cc55cSDimitry Andric         Gathers.push_back(II);
482349cc55cSDimitry Andric       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
483349cc55cSDimitry Andric                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
484349cc55cSDimitry Andric         Scatters.push_back(II);
485349cc55cSDimitry Andric       }
486349cc55cSDimitry Andric     }
487349cc55cSDimitry Andric   }
488349cc55cSDimitry Andric 
489349cc55cSDimitry Andric   // Rewrite gather/scatter to form strided load/store if possible.
490349cc55cSDimitry Andric   for (auto *II : Gathers)
491349cc55cSDimitry Andric     Changed |= tryCreateStridedLoadStore(
492349cc55cSDimitry Andric         II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
493349cc55cSDimitry Andric   for (auto *II : Scatters)
494349cc55cSDimitry Andric     Changed |=
495349cc55cSDimitry Andric         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
496349cc55cSDimitry Andric                                   II->getArgOperand(1), II->getArgOperand(2));
497349cc55cSDimitry Andric 
498349cc55cSDimitry Andric   // Remove any dead phis.
499349cc55cSDimitry Andric   while (!MaybeDeadPHIs.empty()) {
500349cc55cSDimitry Andric     if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
501349cc55cSDimitry Andric       RecursivelyDeleteDeadPHINode(Phi);
502349cc55cSDimitry Andric   }
503349cc55cSDimitry Andric 
504349cc55cSDimitry Andric   return Changed;
505349cc55cSDimitry Andric }
506