xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (revision 0eae32dcef82f6f06de6419a0d623d7def0cc8f6)
1 //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass custom lowers llvm.gather and llvm.scatter instructions to
10 // RISCV intrinsics.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "RISCV.h"
15 #include "RISCVTargetMachine.h"
16 #include "llvm/Analysis/LoopInfo.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/Analysis/VectorUtils.h"
19 #include "llvm/CodeGen/TargetPassConfig.h"
20 #include "llvm/IR/GetElementPtrTypeIterator.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/IR/IntrinsicsRISCV.h"
24 #include "llvm/Transforms/Utils/Local.h"
25 
26 using namespace llvm;
27 
28 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
29 
30 namespace {
31 
// Function pass that rewrites llvm.masked.gather/llvm.masked.scatter over
// fixed-width vectors into RISC-V strided load/store intrinsics when the
// pointer operand can be proven to advance by a loop-invariant stride.
class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  // Vector recurrence phis that were replaced by scalar recurrences; they may
  // become dead once their users are rewritten, so we record them for a final
  // cleanup sweep in runOnFunction. WeakTrackingVH tolerates deletion/RAUW.
  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  // Returns true if the backend can natively handle a strided memory op of
  // DataType with the (constant) alignment operand AlignOp.
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  // Attempts to replace the gather/scatter intrinsic II addressing through
  // Ptr with a riscv_masked_strided_{load,store}; returns true on success.
  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  // Analyzes GEP and, if its single vector index forms a strided recurrence,
  // returns {scalar base pointer (i8*), byte stride}; {nullptr, nullptr}
  // otherwise.
  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  // Walks the use-def chain of Index inside loop L looking for a strided
  // vector recurrence; on success fills the scalar Stride, the new scalar
  // recurrence phi BasePtr, and its increment Inc.
  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};
70 
71 } // end anonymous namespace
72 
73 char RISCVGatherScatterLowering::ID = 0;
74 
75 INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
76                 "RISCV gather/scatter lowering pass", false, false)
77 
78 FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
79   return new RISCVGatherScatterLowering();
80 }
81 
82 bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
83                                                          Value *AlignOp) {
84   Type *ScalarType = DataType->getScalarType();
85   if (!TLI->isLegalElementTypeForRVV(ScalarType))
86     return false;
87 
88   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
89   if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
90     return false;
91 
92   // FIXME: Let the backend type legalize by splitting/widening?
93   EVT DataVT = TLI->getValueType(*DL, DataType);
94   if (!TLI->isTypeLegal(DataVT))
95     return false;
96 
97   return true;
98 }
99 
100 // TODO: Should we consider the mask when looking for a stride?
101 static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
102   unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();
103 
104   // Check that the start value is a strided constant.
105   auto *StartVal =
106       dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
107   if (!StartVal)
108     return std::make_pair(nullptr, nullptr);
109   APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
110   ConstantInt *Prev = StartVal;
111   for (unsigned i = 1; i != NumElts; ++i) {
112     auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
113     if (!C)
114       return std::make_pair(nullptr, nullptr);
115 
116     APInt LocalStride = C->getValue() - Prev->getValue();
117     if (i == 1)
118       StrideVal = LocalStride;
119     else if (StrideVal != LocalStride)
120       return std::make_pair(nullptr, nullptr);
121 
122     Prev = C;
123   }
124 
125   Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
126 
127   return std::make_pair(StartVal, Stride);
128 }
129 
130 // Recursively, walk about the use-def chain until we find a Phi with a strided
131 // start value. Build and update a scalar recurrence as we unwind the recursion.
132 // We also update the Stride as we unwind. Our goal is to move all of the
133 // arithmetic out of the loop.
// Recursively, walk about the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the recursion.
// We also update the Stride as we unwind. Our goal is to move all of the
// arithmetic out of the loop.
//
// Out-parameters on success: Stride is the scalar element stride, BasePtr is
// the newly created scalar recurrence phi, and Inc is its add increment.
// On failure the function returns false; note that IR created for inner
// recursion levels before a later bailout is not rolled back (it is left for
// later DCE since it does not change observable behavior).
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    // Require the canonical "phi + add" recurrence shape.
    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    // Start should be a strided constant.
    auto *StartC = dyn_cast<Constant>(Start);
    if (!StartC)
      return false;

    std::tie(Start, Stride) = matchStridedConstant(StartC);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment. These mirror the vector recurrence but
    // operate on element 0's scalar value; the phi is inserted before the
    // vector phi and the add before the vector increment.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  // Only these opcodes can be folded into the scalar recurrence below.
  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat. The in-loop operand
  // becomes the new Index to recurse on; the other must be loop invariant.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence. The recursion has
  // created BasePtr/Inc; figure out which operand slots hold what.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  // Drop any stale debug location inherited from the previous insert point.
  Builder.SetCurrentDebugLocation(DebugLoc());

  // Fold this binary operator's effect into the scalar recurrence: Add/Or
  // shift only the start; Mul/Shl scale start, step, and stride.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Mul: {
    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1 just take the SplatOpt.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}
291 
// Analyze GEP inside a loop; if it has exactly one vector index that forms a
// strided recurrence, rewrite the index to its scalar equivalent and return
// {scalar base pointer cast to i8*, byte stride}. Returns {nullptr, nullptr}
// when the pattern does not match. Note that on the success path this
// function mutates the IR (creates the scalar GEP/phi/increment).
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  // Make sure we're in a loop and it is in loop simplify form.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->isLoopSimplifyForm())
    return std::make_pair(nullptr, nullptr);

  Optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale. Exactly one index may be a vector;
  // TypeScale records the byte size of the type that index steps through.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    // A second vector index means the address is not a simple stride.
    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    // Scalable element sizes have no fixed byte scale to multiply by.
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedSize();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());

  // Cast the GEP to an i8*.
  LLVMContext &Ctx = GEP->getContext();
  Type *I8PtrTy =
      Type::getInt8PtrTy(Ctx, GEP->getType()->getPointerAddressSpace());
  if (BasePtr->getType() != I8PtrTy)
    BasePtr = Builder.CreatePointerCast(BasePtr, I8PtrTy);

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  return std::make_pair(BasePtr, Stride);
}
381 
382 bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
383                                                            Type *DataType,
384                                                            Value *Ptr,
385                                                            Value *AlignOp) {
386   // Make sure the operation will be supported by the backend.
387   if (!isLegalTypeAndAlignment(DataType, AlignOp))
388     return false;
389 
390   // Pointer should be a GEP.
391   auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
392   if (!GEP)
393     return false;
394 
395   IRBuilder<> Builder(GEP);
396 
397   Value *BasePtr, *Stride;
398   std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
399   if (!BasePtr)
400     return false;
401   assert(Stride != nullptr);
402 
403   Builder.SetInsertPoint(II);
404 
405   CallInst *Call;
406   if (II->getIntrinsicID() == Intrinsic::masked_gather)
407     Call = Builder.CreateIntrinsic(
408         Intrinsic::riscv_masked_strided_load,
409         {DataType, BasePtr->getType(), Stride->getType()},
410         {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
411   else
412     Call = Builder.CreateIntrinsic(
413         Intrinsic::riscv_masked_strided_store,
414         {DataType, BasePtr->getType(), Stride->getType()},
415         {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
416 
417   Call->takeName(II);
418   II->replaceAllUsesWith(Call);
419   II->eraseFromParent();
420 
421   if (GEP->use_empty())
422     RecursivelyDeleteTriviallyDeadInstructions(GEP);
423 
424   return true;
425 }
426 
427 bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
428   if (skipFunction(F))
429     return false;
430 
431   auto &TPC = getAnalysis<TargetPassConfig>();
432   auto &TM = TPC.getTM<RISCVTargetMachine>();
433   ST = &TM.getSubtarget<RISCVSubtarget>(F);
434   if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
435     return false;
436 
437   TLI = ST->getTargetLowering();
438   DL = &F.getParent()->getDataLayout();
439   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
440 
441   SmallVector<IntrinsicInst *, 4> Gathers;
442   SmallVector<IntrinsicInst *, 4> Scatters;
443 
444   bool Changed = false;
445 
446   for (BasicBlock &BB : F) {
447     for (Instruction &I : BB) {
448       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
449       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
450           isa<FixedVectorType>(II->getType())) {
451         Gathers.push_back(II);
452       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
453                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
454         Scatters.push_back(II);
455       }
456     }
457   }
458 
459   // Rewrite gather/scatter to form strided load/store if possible.
460   for (auto *II : Gathers)
461     Changed |= tryCreateStridedLoadStore(
462         II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
463   for (auto *II : Scatters)
464     Changed |=
465         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
466                                   II->getArgOperand(1), II->getArgOperand(2));
467 
468   // Remove any dead phis.
469   while (!MaybeDeadPHIs.empty()) {
470     if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
471       RecursivelyDeleteDeadPHINode(Phi);
472   }
473 
474   return Changed;
475 }
476