xref: /netbsd-src/external/apache2/llvm/dist/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp (revision 82d56013d7b633d116a93943de88e08335357a7c)
//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
//
//===----------------------------------------------------------------------===//
14 
#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
44 
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
48 
49 cl::opt<bool> EnableMaskedGatherScatters(
50     "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
51     cl::desc("Enable the generation of masked gathers and scatters"));
52 
53 namespace {
54 
55 class MVEGatherScatterLowering : public FunctionPass {
56 public:
57   static char ID; // Pass identification, replacement for typeid
58 
MVEGatherScatterLowering()59   explicit MVEGatherScatterLowering() : FunctionPass(ID) {
60     initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
61   }
62 
63   bool runOnFunction(Function &F) override;
64 
getPassName() const65   StringRef getPassName() const override {
66     return "MVE gather/scatter lowering";
67   }
68 
getAnalysisUsage(AnalysisUsage & AU) const69   void getAnalysisUsage(AnalysisUsage &AU) const override {
70     AU.setPreservesCFG();
71     AU.addRequired<TargetPassConfig>();
72     AU.addRequired<LoopInfoWrapperPass>();
73     FunctionPass::getAnalysisUsage(AU);
74   }
75 
76 private:
77   LoopInfo *LI = nullptr;
78 
79   // Check this is a valid gather with correct alignment
80   bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
81                                Align Alignment);
82   // Check whether Ptr is hidden behind a bitcast and look through it
83   void lookThroughBitcast(Value *&Ptr);
84   // Check for a getelementptr and deduce base and offsets from it, on success
85   // returning the base directly and the offsets indirectly using the Offsets
86   // argument
87   Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP,
88                   IRBuilder<> &Builder);
89   // Compute the scale of this gather/scatter instruction
90   int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
91   // If the value is a constant, or derived from constants via additions
92   // and multilications, return its numeric value
93   Optional<int64_t> getIfConst(const Value *V);
94   // If Inst is an add instruction, check whether one summand is a
95   // constant. If so, scale this constant and return it together with
96   // the other summand.
97   std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
98 
99   Value *lowerGather(IntrinsicInst *I);
100   // Create a gather from a base + vector of offsets
101   Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
102                                      Instruction *&Root, IRBuilder<> &Builder);
103   // Create a gather from a vector of pointers
104   Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
105                                    IRBuilder<> &Builder, int64_t Increment = 0);
106   // Create an incrementing gather from a vector of pointers
107   Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
108                                      IRBuilder<> &Builder,
109                                      int64_t Increment = 0);
110 
111   Value *lowerScatter(IntrinsicInst *I);
112   // Create a scatter to a base + vector of offsets
113   Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
114                                       IRBuilder<> &Builder);
115   // Create a scatter to a vector of pointers
116   Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
117                                     IRBuilder<> &Builder,
118                                     int64_t Increment = 0);
119   // Create an incrementing scatter from a vector of pointers
120   Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
121                                       IRBuilder<> &Builder,
122                                       int64_t Increment = 0);
123 
124   // QI gathers and scatters can increment their offsets on their own if
125   // the increment is a constant value (digit)
126   Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
127                                       Value *Ptr, GetElementPtrInst *GEP,
128                                       IRBuilder<> &Builder);
129   // QI gathers/scatters can increment their offsets on their own if the
130   // increment is a constant value (digit) - this creates a writeback QI
131   // gather/scatter
132   Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
133                                         Value *Ptr, unsigned TypeScale,
134                                         IRBuilder<> &Builder);
135 
136   // Optimise the base and offsets of the given address
137   bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
138   // Try to fold consecutive geps together into one
139   Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
140   // Check whether these offsets could be moved out of the loop they're in
141   bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
142   // Pushes the given add out of the loop
143   void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
144   // Pushes the given mul out of the loop
145   void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
146                   Value *OffsSecondOperand, unsigned LoopIncrement,
147                   IRBuilder<> &Builder);
148 };
149 
150 } // end anonymous namespace
151 
152 char MVEGatherScatterLowering::ID = 0;
153 
154 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
155                 "MVE gather/scattering lowering pass", false, false)
156 
createMVEGatherScatterLoweringPass()157 Pass *llvm::createMVEGatherScatterLoweringPass() {
158   return new MVEGatherScatterLowering();
159 }
160 
isLegalTypeAndAlignment(unsigned NumElements,unsigned ElemSize,Align Alignment)161 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
162                                                        unsigned ElemSize,
163                                                        Align Alignment) {
164   if (((NumElements == 4 &&
165         (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
166        (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
167        (NumElements == 16 && ElemSize == 8)) &&
168       Alignment >= ElemSize / 8)
169     return true;
170   LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
171                     << "valid alignment or vector type \n");
172   return false;
173 }
174 
checkOffsetSize(Value * Offsets,unsigned TargetElemCount)175 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
176   // Offsets that are not of type <N x i32> are sign extended by the
177   // getelementptr instruction, and MVE gathers/scatters treat the offset as
178   // unsigned. Thus, if the element size is smaller than 32, we can only allow
179   // positive offsets - i.e., the offsets are not allowed to be variables we
180   // can't look into.
181   // Additionally, <N x i32> offsets have to either originate from a zext of a
182   // vector with element types smaller or equal the type of the gather we're
183   // looking at, or consist of constants that we can check are small enough
184   // to fit into the gather type.
185   // Thus we check that 0 < value < 2^TargetElemSize.
186   unsigned TargetElemSize = 128 / TargetElemCount;
187   unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
188                                 ->getElementType()
189                                 ->getScalarSizeInBits();
190   if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
191     Constant *ConstOff = dyn_cast<Constant>(Offsets);
192     if (!ConstOff)
193       return false;
194     int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
195     auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
196       ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
197       if (!OConst)
198         return false;
199       int SExtValue = OConst->getSExtValue();
200       if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
201         return false;
202       return true;
203     };
204     if (isa<FixedVectorType>(ConstOff->getType())) {
205       for (unsigned i = 0; i < TargetElemCount; i++) {
206         if (!CheckValueSize(ConstOff->getAggregateElement(i)))
207           return false;
208       }
209     } else {
210       if (!CheckValueSize(ConstOff))
211         return false;
212     }
213   }
214   return true;
215 }
216 
checkGEP(Value * & Offsets,FixedVectorType * Ty,GetElementPtrInst * GEP,IRBuilder<> & Builder)217 Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty,
218                                           GetElementPtrInst *GEP,
219                                           IRBuilder<> &Builder) {
220   if (!GEP) {
221     LLVM_DEBUG(
222         dbgs() << "masked gathers/scatters: no getelementpointer found\n");
223     return nullptr;
224   }
225   LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
226                     << " Looking at intrinsic for base + vector of offsets\n");
227   Value *GEPPtr = GEP->getPointerOperand();
228   Offsets = GEP->getOperand(1);
229   if (GEPPtr->getType()->isVectorTy() ||
230       !isa<FixedVectorType>(Offsets->getType()))
231     return nullptr;
232 
233   if (GEP->getNumOperands() != 2) {
234     LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
235                       << " operands. Expanding.\n");
236     return nullptr;
237   }
238   Offsets = GEP->getOperand(1);
239   unsigned OffsetsElemCount =
240       cast<FixedVectorType>(Offsets->getType())->getNumElements();
241   // Paranoid check whether the number of parallel lanes is the same
242   assert(Ty->getNumElements() == OffsetsElemCount);
243 
244   ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
245   if (ZextOffs)
246     Offsets = ZextOffs->getOperand(0);
247   FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
248 
249   // If the offsets are already being zext-ed to <N x i32>, that relieves us of
250   // having to make sure that they won't overflow.
251   if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
252                            ->getElementType()
253                            ->getScalarSizeInBits() != 32)
254     if (!checkOffsetSize(Offsets, OffsetsElemCount))
255       return nullptr;
256 
257   // The offset sizes have been checked; if any truncating or zext-ing is
258   // required to fix them, do that now
259   if (Ty != Offsets->getType()) {
260     if ((Ty->getElementType()->getScalarSizeInBits() <
261          OffsetType->getElementType()->getScalarSizeInBits())) {
262       Offsets = Builder.CreateTrunc(Offsets, Ty);
263     } else {
264       Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
265     }
266   }
267   // If none of the checks failed, return the gep's base pointer
268   LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
269   return GEPPtr;
270 }
271 
lookThroughBitcast(Value * & Ptr)272 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
273   // Look through bitcast instruction if #elements is the same
274   if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
275     auto *BCTy = cast<FixedVectorType>(BitCast->getType());
276     auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
277     if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
278       LLVM_DEBUG(
279           dbgs() << "masked gathers/scatters: looking through bitcast\n");
280       Ptr = BitCast->getOperand(0);
281     }
282   }
283 }
284 
computeScale(unsigned GEPElemSize,unsigned MemoryElemSize)285 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
286                                            unsigned MemoryElemSize) {
287   // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
288   // or a 8bit, 16bit or 32bit load/store scaled by 1
289   if (GEPElemSize == 32 && MemoryElemSize == 32)
290     return 2;
291   else if (GEPElemSize == 16 && MemoryElemSize == 16)
292     return 1;
293   else if (GEPElemSize == 8)
294     return 0;
295   LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
296                     << "create intrinsic\n");
297   return -1;
298 }
299 
getIfConst(const Value * V)300 Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
301   const Constant *C = dyn_cast<Constant>(V);
302   if (C != nullptr)
303     return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
304   if (!isa<Instruction>(V))
305     return Optional<int64_t>{};
306 
307   const Instruction *I = cast<Instruction>(V);
308   if (I->getOpcode() == Instruction::Add ||
309               I->getOpcode() == Instruction::Mul) {
310     Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
311     Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
312     if (!Op0 || !Op1)
313       return Optional<int64_t>{};
314     if (I->getOpcode() == Instruction::Add)
315       return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
316     if (I->getOpcode() == Instruction::Mul)
317       return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
318   }
319   return Optional<int64_t>{};
320 }
321 
322 std::pair<Value *, int64_t>
getVarAndConst(Value * Inst,int TypeScale)323 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
324   std::pair<Value *, int64_t> ReturnFalse =
325       std::pair<Value *, int64_t>(nullptr, 0);
326   // At this point, the instruction we're looking at must be an add or we
327   // bail out
328   Instruction *Add = dyn_cast<Instruction>(Inst);
329   if (Add == nullptr || Add->getOpcode() != Instruction::Add)
330     return ReturnFalse;
331 
332   Value *Summand;
333   Optional<int64_t> Const;
334   // Find out which operand the value that is increased is
335   if ((Const = getIfConst(Add->getOperand(0))))
336     Summand = Add->getOperand(1);
337   else if ((Const = getIfConst(Add->getOperand(1))))
338     Summand = Add->getOperand(0);
339   else
340     return ReturnFalse;
341 
342   // Check that the constant is small enough for an incrementing gather
343   int64_t Immediate = Const.getValue() << TypeScale;
344   if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
345     return ReturnFalse;
346 
347   return std::pair<Value *, int64_t>(Summand, Immediate);
348 }
349 
lowerGather(IntrinsicInst * I)350 Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
351   using namespace PatternMatch;
352   LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
353                     << *I << "\n");
354 
355   // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
356   // Attempt to turn the masked gather in I into a MVE intrinsic
357   // Potentially optimising the addressing modes as we do so.
358   auto *Ty = cast<FixedVectorType>(I->getType());
359   Value *Ptr = I->getArgOperand(0);
360   Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
361   Value *Mask = I->getArgOperand(2);
362   Value *PassThru = I->getArgOperand(3);
363 
364   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
365                                Alignment))
366     return nullptr;
367   lookThroughBitcast(Ptr);
368   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
369 
370   IRBuilder<> Builder(I->getContext());
371   Builder.SetInsertPoint(I);
372   Builder.SetCurrentDebugLocation(I->getDebugLoc());
373 
374   Instruction *Root = I;
375   Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
376   if (!Load)
377     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
378   if (!Load)
379     return nullptr;
380 
381   if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
382     LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
383                       << "creating select\n");
384     Load = Builder.CreateSelect(Mask, Load, PassThru);
385   }
386 
387   Root->replaceAllUsesWith(Load);
388   Root->eraseFromParent();
389   if (Root != I)
390     // If this was an extending gather, we need to get rid of the sext/zext
391     // sext/zext as well as of the gather itself
392     I->eraseFromParent();
393 
394   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
395                     << *Load << "\n");
396   return Load;
397 }
398 
tryCreateMaskedGatherBase(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder,int64_t Increment)399 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
400                                                            Value *Ptr,
401                                                            IRBuilder<> &Builder,
402                                                            int64_t Increment) {
403   using namespace PatternMatch;
404   auto *Ty = cast<FixedVectorType>(I->getType());
405   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
406   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
407     // Can't build an intrinsic for this
408     return nullptr;
409   Value *Mask = I->getArgOperand(2);
410   if (match(Mask, m_One()))
411     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
412                                    {Ty, Ptr->getType()},
413                                    {Ptr, Builder.getInt32(Increment)});
414   else
415     return Builder.CreateIntrinsic(
416         Intrinsic::arm_mve_vldr_gather_base_predicated,
417         {Ty, Ptr->getType(), Mask->getType()},
418         {Ptr, Builder.getInt32(Increment), Mask});
419 }
420 
tryCreateMaskedGatherBaseWB(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder,int64_t Increment)421 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
422     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
423   using namespace PatternMatch;
424   auto *Ty = cast<FixedVectorType>(I->getType());
425   LLVM_DEBUG(
426       dbgs()
427       << "masked gathers: loading from vector of pointers with writeback\n");
428   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
429     // Can't build an intrinsic for this
430     return nullptr;
431   Value *Mask = I->getArgOperand(2);
432   if (match(Mask, m_One()))
433     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
434                                    {Ty, Ptr->getType()},
435                                    {Ptr, Builder.getInt32(Increment)});
436   else
437     return Builder.CreateIntrinsic(
438         Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
439         {Ty, Ptr->getType(), Mask->getType()},
440         {Ptr, Builder.getInt32(Increment), Mask});
441 }
442 
tryCreateMaskedGatherOffset(IntrinsicInst * I,Value * Ptr,Instruction * & Root,IRBuilder<> & Builder)443 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
444     IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
445   using namespace PatternMatch;
446 
447   Type *OriginalTy = I->getType();
448   Type *ResultTy = OriginalTy;
449 
450   unsigned Unsigned = 1;
451   // The size of the gather was already checked in isLegalTypeAndAlignment;
452   // if it was not a full vector width an appropriate extend should follow.
453   auto *Extend = Root;
454   if (OriginalTy->getPrimitiveSizeInBits() < 128) {
455     // Only transform gathers with exactly one use
456     if (!I->hasOneUse())
457       return nullptr;
458 
459     // The correct root to replace is not the CallInst itself, but the
460     // instruction which extends it
461     Extend = cast<Instruction>(*I->users().begin());
462     if (isa<SExtInst>(Extend)) {
463       Unsigned = 0;
464     } else if (!isa<ZExtInst>(Extend)) {
465       LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
466                         << "Expanding\n");
467       return nullptr;
468     }
469     LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
470     ResultTy = Extend->getType();
471     // The final size of the gather must be a full vector width
472     if (ResultTy->getPrimitiveSizeInBits() != 128) {
473       LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
474                         << "Expanding\n");
475       return nullptr;
476     }
477   }
478 
479   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
480   Value *Offsets;
481   Value *BasePtr =
482       checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder);
483   if (!BasePtr)
484     return nullptr;
485   // Check whether the offset is a constant increment that could be merged into
486   // a QI gather
487   Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
488   if (Load)
489     return Load;
490 
491   int Scale = computeScale(
492       BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
493       OriginalTy->getScalarSizeInBits());
494   if (Scale == -1)
495     return nullptr;
496   Root = Extend;
497 
498   Value *Mask = I->getArgOperand(2);
499   if (!match(Mask, m_One()))
500     return Builder.CreateIntrinsic(
501         Intrinsic::arm_mve_vldr_gather_offset_predicated,
502         {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
503         {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
504          Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
505   else
506     return Builder.CreateIntrinsic(
507         Intrinsic::arm_mve_vldr_gather_offset,
508         {ResultTy, BasePtr->getType(), Offsets->getType()},
509         {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
510          Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
511 }
512 
lowerScatter(IntrinsicInst * I)513 Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
514   using namespace PatternMatch;
515   LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
516                     << *I << "\n");
517 
518   // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
519   // Attempt to turn the masked scatter in I into a MVE intrinsic
520   // Potentially optimising the addressing modes as we do so.
521   Value *Input = I->getArgOperand(0);
522   Value *Ptr = I->getArgOperand(1);
523   Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
524   auto *Ty = cast<FixedVectorType>(Input->getType());
525 
526   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
527                                Alignment))
528     return nullptr;
529 
530   lookThroughBitcast(Ptr);
531   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
532 
533   IRBuilder<> Builder(I->getContext());
534   Builder.SetInsertPoint(I);
535   Builder.SetCurrentDebugLocation(I->getDebugLoc());
536 
537   Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
538   if (!Store)
539     Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
540   if (!Store)
541     return nullptr;
542 
543   LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
544                     << *Store << "\n");
545   I->eraseFromParent();
546   return Store;
547 }
548 
tryCreateMaskedScatterBase(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder,int64_t Increment)549 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
550     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
551   using namespace PatternMatch;
552   Value *Input = I->getArgOperand(0);
553   auto *Ty = cast<FixedVectorType>(Input->getType());
554   // Only QR variants allow truncating
555   if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
556     // Can't build an intrinsic for this
557     return nullptr;
558   }
559   Value *Mask = I->getArgOperand(3);
560   //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
561   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
562   if (match(Mask, m_One()))
563     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
564                                    {Ptr->getType(), Input->getType()},
565                                    {Ptr, Builder.getInt32(Increment), Input});
566   else
567     return Builder.CreateIntrinsic(
568         Intrinsic::arm_mve_vstr_scatter_base_predicated,
569         {Ptr->getType(), Input->getType(), Mask->getType()},
570         {Ptr, Builder.getInt32(Increment), Input, Mask});
571 }
572 
tryCreateMaskedScatterBaseWB(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder,int64_t Increment)573 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
574     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
575   using namespace PatternMatch;
576   Value *Input = I->getArgOperand(0);
577   auto *Ty = cast<FixedVectorType>(Input->getType());
578   LLVM_DEBUG(
579       dbgs()
580       << "masked scatters: storing to a vector of pointers with writeback\n");
581   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
582     // Can't build an intrinsic for this
583     return nullptr;
584   Value *Mask = I->getArgOperand(3);
585   if (match(Mask, m_One()))
586     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
587                                    {Ptr->getType(), Input->getType()},
588                                    {Ptr, Builder.getInt32(Increment), Input});
589   else
590     return Builder.CreateIntrinsic(
591         Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
592         {Ptr->getType(), Input->getType(), Mask->getType()},
593         {Ptr, Builder.getInt32(Increment), Input, Mask});
594 }
595 
tryCreateMaskedScatterOffset(IntrinsicInst * I,Value * Ptr,IRBuilder<> & Builder)596 Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
597     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
598   using namespace PatternMatch;
599   Value *Input = I->getArgOperand(0);
600   Value *Mask = I->getArgOperand(3);
601   Type *InputTy = Input->getType();
602   Type *MemoryTy = InputTy;
603   LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
604                     << " to base + vector of offsets\n");
605   // If the input has been truncated, try to integrate that trunc into the
606   // scatter instruction (we don't care about alignment here)
607   if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
608     Value *PreTrunc = Trunc->getOperand(0);
609     Type *PreTruncTy = PreTrunc->getType();
610     if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
611       Input = PreTrunc;
612       InputTy = PreTruncTy;
613     }
614   }
615   if (InputTy->getPrimitiveSizeInBits() != 128) {
616     LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
617                          "non-standard input types. Expanding.\n");
618     return nullptr;
619   }
620 
621   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
622   Value *Offsets;
623   Value *BasePtr =
624       checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder);
625   if (!BasePtr)
626     return nullptr;
627   // Check whether the offset is a constant increment that could be merged into
628   // a QI gather
629   Value *Store =
630       tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
631   if (Store)
632     return Store;
633   int Scale = computeScale(
634       BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
635       MemoryTy->getScalarSizeInBits());
636   if (Scale == -1)
637     return nullptr;
638 
639   if (!match(Mask, m_One()))
640     return Builder.CreateIntrinsic(
641         Intrinsic::arm_mve_vstr_scatter_offset_predicated,
642         {BasePtr->getType(), Offsets->getType(), Input->getType(),
643          Mask->getType()},
644         {BasePtr, Offsets, Input,
645          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
646          Builder.getInt32(Scale), Mask});
647   else
648     return Builder.CreateIntrinsic(
649         Intrinsic::arm_mve_vstr_scatter_offset,
650         {BasePtr->getType(), Offsets->getType(), Input->getType()},
651         {BasePtr, Offsets, Input,
652          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
653          Builder.getInt32(Scale)});
654 }
655 
tryCreateIncrementingGatScat(IntrinsicInst * I,Value * BasePtr,Value * Offsets,GetElementPtrInst * GEP,IRBuilder<> & Builder)656 Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
657     IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
658     IRBuilder<> &Builder) {
659   FixedVectorType *Ty;
660   if (I->getIntrinsicID() == Intrinsic::masked_gather)
661     Ty = cast<FixedVectorType>(I->getType());
662   else
663     Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
664   // Incrementing gathers only exist for v4i32
665   if (Ty->getNumElements() != 4 ||
666       Ty->getScalarSizeInBits() != 32)
667     return nullptr;
668   Loop *L = LI->getLoopFor(I->getParent());
669   if (L == nullptr)
670     // Incrementing gathers are not beneficial outside of a loop
671     return nullptr;
672   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
673                        "wb gather/scatter\n");
674 
675   // The gep was in charge of making sure the offsets are scaled correctly
676   // - calculate that factor so it can be applied by hand
677   DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
678   int TypeScale =
679       computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
680                    DT.getTypeSizeInBits(GEP->getType()) /
681                        cast<FixedVectorType>(GEP->getType())->getNumElements());
682   if (TypeScale == -1)
683     return nullptr;
684 
685   if (GEP->hasOneUse()) {
686     // Only in this case do we want to build a wb gather, because the wb will
687     // change the phi which does affect other users of the gep (which will still
688     // be using the phi in the old way)
689     Value *Load =
690         tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
691     if (Load != nullptr)
692       return Load;
693   }
694   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
695                        "non-wb gather/scatter\n");
696 
697   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
698   if (Add.first == nullptr)
699     return nullptr;
700   Value *OffsetsIncoming = Add.first;
701   int64_t Immediate = Add.second;
702 
703   // Make sure the offsets are scaled correctly
704   Instruction *ScaledOffsets = BinaryOperator::Create(
705       Instruction::Shl, OffsetsIncoming,
706       Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
707       "ScaledIndex", I);
708   // Add the base to the offsets
709   OffsetsIncoming = BinaryOperator::Create(
710       Instruction::Add, ScaledOffsets,
711       Builder.CreateVectorSplat(
712           Ty->getNumElements(),
713           Builder.CreatePtrToInt(
714               BasePtr,
715               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
716       "StartIndex", I);
717 
718   if (I->getIntrinsicID() == Intrinsic::masked_gather)
719     return cast<IntrinsicInst>(
720         tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
721   else
722     return cast<IntrinsicInst>(
723         tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
724 }
725 
/// Try to build a write-back ("wb") incrementing gather/scatter: the loop's
/// offset phi itself is rewritten to hold pre-scaled addresses and the MVE
/// intrinsic produces the next induction value, absorbing the phi increment.
/// Returns the replacement value for the gather/scatter, or nullptr.
Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  // Index of the incoming value that comes from the loop latch (the phi
  // increment); 1 - IncrementIndex is then the entry/preheader side.
  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  // Split the increment into its loop-varying part and constant step.
  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  // Insert the start-value computation just before the terminator of the
  // non-latch (preheader) incoming block.
  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing, so start one step (Immediate) early.
  // NOTE(review): Immediate is int64_t but splatted as i32 here — presumably
  // getVarAndConst guarantees it fits; confirm against its implementation.
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Redirect the phi's entry value to the pre-scaled, pre-decremented start.
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  // The old phi increment (Offsets) is now dead: the intrinsic's write-back
  // result replaces it as the next induction value.
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}
807 
pushOutAdd(PHINode * & Phi,Value * OffsSecondOperand,unsigned StartIndex)808 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
809                                           Value *OffsSecondOperand,
810                                           unsigned StartIndex) {
811   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
812   Instruction *InsertionPoint =
813         &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
814   // Initialize the phi with a vector that contains a sum of the constants
815   Instruction *NewIndex = BinaryOperator::Create(
816       Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
817       "PushedOutAdd", InsertionPoint);
818   unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
819 
820   // Order such that start index comes first (this reduces mov's)
821   Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
822   Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
823                    Phi->getIncomingBlock(IncrementIndex));
824   Phi->removeIncomingValue(IncrementIndex);
825   Phi->removeIncomingValue(StartIndex);
826 }
827 
/// Hoist a loop-invariant multiply out of the loop: the phi's start value is
/// multiplied once outside the loop, and the per-iteration increment is
/// replaced by (IncrementPerRound * OffsSecondOperand). LoopIncrement is the
/// index of the phi's incoming entry that comes from the loop latch.
void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which loop variable can be incremented.
  // Insertion point: the last instruction of the non-latch (entry) block.
  Instruction *InsertionPoint = &cast<Instruction>(
        Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index: start value multiplied once, outside the loop.
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  // The new per-round step; OffsSecondOperand is loop-invariant, so this is
  // also emitted outside the loop.
  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication.
  // Inserted just before the latch block's last instruction so it dominates
  // the back edge use.
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  // The two original incoming entries sit at indices 0 and 1; removing
  // index 0 twice deletes both, leaving only the two entries added above.
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}
860 
861 // Check whether all usages of this instruction are as offsets of
862 // gathers/scatters or simple arithmetics only used by gathers/scatters
hasAllGatScatUsers(Instruction * I)863 static bool hasAllGatScatUsers(Instruction *I) {
864   if (I->hasNUses(0)) {
865     return false;
866   }
867   bool Gatscat = true;
868   for (User *U : I->users()) {
869     if (!isa<Instruction>(U))
870       return false;
871     if (isa<GetElementPtrInst>(U) ||
872         isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
873       return Gatscat;
874     } else {
875       unsigned OpCode = cast<Instruction>(U)->getOpcode();
876       if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
877           hasAllGatScatUsers(cast<Instruction>(U))) {
878         continue;
879       }
880       return false;
881     }
882   }
883   return Gatscat;
884 }
885 
optimiseOffsets(Value * Offsets,BasicBlock * BB,LoopInfo * LI)886 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
887                                                LoopInfo *LI) {
888   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
889                     << *Offsets << "\n");
890   // Optimise the addresses of gathers/scatters by moving invariant
891   // calculations out of the loop
892   if (!isa<Instruction>(Offsets))
893     return false;
894   Instruction *Offs = cast<Instruction>(Offsets);
895   if (Offs->getOpcode() != Instruction::Add &&
896       Offs->getOpcode() != Instruction::Mul)
897     return false;
898   Loop *L = LI->getLoopFor(BB);
899   if (L == nullptr)
900     return false;
901   if (!Offs->hasOneUse()) {
902     if (!hasAllGatScatUsers(Offs))
903       return false;
904   }
905 
906   // Find out which, if any, operand of the instruction
907   // is a phi node
908   PHINode *Phi;
909   int OffsSecondOp;
910   if (isa<PHINode>(Offs->getOperand(0))) {
911     Phi = cast<PHINode>(Offs->getOperand(0));
912     OffsSecondOp = 1;
913   } else if (isa<PHINode>(Offs->getOperand(1))) {
914     Phi = cast<PHINode>(Offs->getOperand(1));
915     OffsSecondOp = 0;
916   } else {
917     bool Changed = true;
918     if (isa<Instruction>(Offs->getOperand(0)) &&
919         L->contains(cast<Instruction>(Offs->getOperand(0))))
920       Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
921     if (isa<Instruction>(Offs->getOperand(1)) &&
922         L->contains(cast<Instruction>(Offs->getOperand(1))))
923       Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
924     if (!Changed) {
925       return false;
926     } else {
927       if (isa<PHINode>(Offs->getOperand(0))) {
928         Phi = cast<PHINode>(Offs->getOperand(0));
929         OffsSecondOp = 1;
930       } else if (isa<PHINode>(Offs->getOperand(1))) {
931         Phi = cast<PHINode>(Offs->getOperand(1));
932         OffsSecondOp = 0;
933       } else {
934         return false;
935       }
936     }
937   }
938   // A phi node we want to perform this function on should be from the
939   // loop header, and shouldn't have more than 2 incoming values
940   if (Phi->getParent() != L->getHeader() ||
941       Phi->getNumIncomingValues() != 2)
942     return false;
943 
944   // The phi must be an induction variable
945   int IncrementingBlock = -1;
946 
947   for (int i = 0; i < 2; i++)
948     if (auto *Op = dyn_cast<Instruction>(Phi->getIncomingValue(i)))
949       if (Op->getOpcode() == Instruction::Add &&
950           (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
951         IncrementingBlock = i;
952   if (IncrementingBlock == -1)
953     return false;
954 
955   Instruction *IncInstruction =
956       cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));
957 
958   // If the phi is not used by anything else, we can just adapt it when
959   // replacing the instruction; if it is, we'll have to duplicate it
960   PHINode *NewPhi;
961   Value *IncrementPerRound = IncInstruction->getOperand(
962       (IncInstruction->getOperand(0) == Phi) ? 1 : 0);
963 
964   // Get the value that is added to/multiplied with the phi
965   Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
966 
967   if (IncrementPerRound->getType() != OffsSecondOperand->getType())
968     // Something has gone wrong, abort
969     return false;
970 
971   // Only proceed if the increment per round is a constant or an instruction
972   // which does not originate from within the loop
973   if (!isa<Constant>(IncrementPerRound) &&
974       !(isa<Instruction>(IncrementPerRound) &&
975         !L->contains(cast<Instruction>(IncrementPerRound))))
976     return false;
977 
978   if (Phi->getNumUses() == 2) {
979     // No other users -> reuse existing phi (One user is the instruction
980     // we're looking at, the other is the phi increment)
981     if (IncInstruction->getNumUses() != 1) {
982       // If the incrementing instruction does have more users than
983       // our phi, we need to copy it
984       IncInstruction = BinaryOperator::Create(
985           Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
986           IncrementPerRound, "LoopIncrement", IncInstruction);
987       Phi->setIncomingValue(IncrementingBlock, IncInstruction);
988     }
989     NewPhi = Phi;
990   } else {
991     // There are other users -> create a new phi
992     NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
993     std::vector<Value *> Increases;
994     // Copy the incoming values of the old phi
995     NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
996                         Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
997     IncInstruction = BinaryOperator::Create(
998         Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
999         IncrementPerRound, "LoopIncrement", IncInstruction);
1000     NewPhi->addIncoming(IncInstruction,
1001                         Phi->getIncomingBlock(IncrementingBlock));
1002     IncrementingBlock = 1;
1003   }
1004 
1005   IRBuilder<> Builder(BB->getContext());
1006   Builder.SetInsertPoint(Phi);
1007   Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
1008 
1009   switch (Offs->getOpcode()) {
1010   case Instruction::Add:
1011     pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
1012     break;
1013   case Instruction::Mul:
1014     pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
1015                Builder);
1016     break;
1017   default:
1018     return false;
1019   }
1020   LLVM_DEBUG(
1021       dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");
1022 
1023   // The instruction has now been "absorbed" into the phi value
1024   Offs->replaceAllUsesWith(NewPhi);
1025   if (Offs->hasNUses(0))
1026     Offs->eraseFromParent();
1027   // Clean up the old increment in case it's unused because we built a new
1028   // one
1029   if (IncInstruction->hasNUses(0))
1030     IncInstruction->eraseFromParent();
1031 
1032   return true;
1033 }
1034 
/// Add two gep offset operands, splatting a scalar operand to the other's
/// vector type if needed, and verifying that the sum cannot overflow the
/// narrow (sub-32-bit) offset element type. Returns the add, or nullptr if
/// the operands cannot safely be combined.
static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      // Only narrow the constant if it fits in the signed range of the
      // target element type.
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    // Fallback: splat at the value's own type; the XElType != YElType check
    // below rejects the pair if the element types still do not match.
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow. Only constant vectors can be proven safe,
    // element by element.
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    // The element size the MVE gather/scatter would use for this lane count
    // (128-bit vectors).
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      // Reject non-integer elements and any per-lane sum that would leave
      // the signed range of the target element size.
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  // checkOffsetSize (defined elsewhere in this file) validates that the
  // merged offsets still fit what an MVE gather/scatter can encode.
  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}
1102 
foldGEP(GetElementPtrInst * GEP,Value * & Offsets,IRBuilder<> & Builder)1103 Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
1104                                          Value *&Offsets,
1105                                          IRBuilder<> &Builder) {
1106   Value *GEPPtr = GEP->getPointerOperand();
1107   Offsets = GEP->getOperand(1);
1108   // We only merge geps with constant offsets, because only for those
1109   // we can make sure that we do not cause an overflow
1110   if (!isa<Constant>(Offsets))
1111     return nullptr;
1112   GetElementPtrInst *BaseGEP;
1113   if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
1114     // Merge the two geps into one
1115     Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
1116     if (!BaseBasePtr)
1117       return nullptr;
1118     Offsets =
1119         CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
1120     if (Offsets == nullptr)
1121       return nullptr;
1122     return BaseBasePtr;
1123   }
1124   return GEPPtr;
1125 }
1126 
optimiseAddress(Value * Address,BasicBlock * BB,LoopInfo * LI)1127 bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
1128                                                LoopInfo *LI) {
1129   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
1130   if (!GEP)
1131     return false;
1132   bool Changed = false;
1133   if (GEP->hasOneUse() &&
1134       dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
1135     IRBuilder<> Builder(GEP->getContext());
1136     Builder.SetInsertPoint(GEP);
1137     Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
1138     Value *Offsets;
1139     Value *Base = foldGEP(GEP, Offsets, Builder);
1140     // We only want to merge the geps if there is a real chance that they can be
1141     // used by an MVE gather; thus the offset has to have the correct size
1142     // (always i32 if it is not of vector type) and the base has to be a
1143     // pointer.
1144     if (Offsets && Base && Base != GEP) {
1145       PointerType *BaseType = cast<PointerType>(Base->getType());
1146       GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
1147           BaseType->getPointerElementType(), Base, Offsets, "gep.merged", GEP);
1148       GEP->replaceAllUsesWith(NewAddress);
1149       GEP = NewAddress;
1150       Changed = true;
1151     }
1152   }
1153   Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
1154   return Changed;
1155 }
1156 
runOnFunction(Function & F)1157 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
1158   if (!EnableMaskedGatherScatters)
1159     return false;
1160   auto &TPC = getAnalysis<TargetPassConfig>();
1161   auto &TM = TPC.getTM<TargetMachine>();
1162   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1163   if (!ST->hasMVEIntegerOps())
1164     return false;
1165   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1166   SmallVector<IntrinsicInst *, 4> Gathers;
1167   SmallVector<IntrinsicInst *, 4> Scatters;
1168 
1169   bool Changed = false;
1170 
1171   for (BasicBlock &BB : F) {
1172     for (Instruction &I : BB) {
1173       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
1174       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
1175           isa<FixedVectorType>(II->getType())) {
1176         Gathers.push_back(II);
1177         Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
1178       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
1179                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
1180         Scatters.push_back(II);
1181         Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
1182       }
1183     }
1184   }
1185   for (unsigned i = 0; i < Gathers.size(); i++) {
1186     IntrinsicInst *I = Gathers[i];
1187     Value *L = lowerGather(I);
1188     if (L == nullptr)
1189       continue;
1190 
1191     // Get rid of any now dead instructions
1192     SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
1193     Changed = true;
1194   }
1195 
1196   for (unsigned i = 0; i < Scatters.size(); i++) {
1197     IntrinsicInst *I = Scatters[i];
1198     Value *S = lowerScatter(I);
1199     if (S == nullptr)
1200       continue;
1201 
1202     // Get rid of any now dead instructions
1203     SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
1204     Changed = true;
1205   }
1206   return Changed;
1207 }
1208