1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
10 /// intrinsics to Arm MVE gather and scatter intrinsics, optimising the code
11 /// to produce a better final result as we go.
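/// For example (illustrative only), a masked gather such as
///   %r = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4,
///                                                      <4 x i1> %mask, <4 x i32> undef)
/// can be turned into a single arm.mve.vldr.gather.base(.predicated) call that
/// maps onto one MVE vector gather load.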
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "ARM.h"
16 #include "ARMBaseInstrInfo.h"
17 #include "ARMSubtarget.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/CodeGen/TargetLowering.h"
22 #include "llvm/CodeGen/TargetPassConfig.h"
23 #include "llvm/CodeGen/TargetSubtargetInfo.h"
24 #include "llvm/IR/BasicBlock.h"
25 #include "llvm/IR/Constant.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/InstrTypes.h"
31 #include "llvm/IR/Instruction.h"
32 #include "llvm/IR/Instructions.h"
33 #include "llvm/IR/IntrinsicInst.h"
34 #include "llvm/IR/Intrinsics.h"
35 #include "llvm/IR/IntrinsicsARM.h"
36 #include "llvm/IR/PatternMatch.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/InitializePasses.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include <cassert>
44 
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
48 
49 cl::opt<bool> EnableMaskedGatherScatters(
50     "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
51     cl::desc("Enable the generation of masked gathers and scatters"));
52 
53 namespace {
54 
55 class MVEGatherScatterLowering : public FunctionPass {
56 public:
57   static char ID; // Pass identification, replacement for typeid
58 
59   explicit MVEGatherScatterLowering() : FunctionPass(ID) {
60     initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
61   }
62 
63   bool runOnFunction(Function &F) override;
64 
65   StringRef getPassName() const override {
66     return "MVE gather/scatter lowering";
67   }
68 
69   void getAnalysisUsage(AnalysisUsage &AU) const override {
70     AU.setPreservesCFG();
71     AU.addRequired<TargetPassConfig>();
72     AU.addRequired<LoopInfoWrapperPass>();
73     FunctionPass::getAnalysisUsage(AU);
74   }
75 
76 private:
77   LoopInfo *LI = nullptr;
78   const DataLayout *DL;
79 
80   // Check this is a valid gather with correct alignment
81   bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
82                                Align Alignment);
83   // Check whether Ptr is hidden behind a bitcast and look through it
84   void lookThroughBitcast(Value *&Ptr);
85   // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
86   // scalar base and vector offsets, or else fall back to using a base of 0 and
87   // offset of Ptr where possible.
88   Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
89                       FixedVectorType *Ty, Type *MemoryTy,
90                       IRBuilder<> &Builder);
91   // Check for a getelementptr and deduce base and offsets from it, on success
92   // returning the base directly and the offsets indirectly using the Offsets
93   // argument
94   Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
95                       GetElementPtrInst *GEP, IRBuilder<> &Builder);
96   // Compute the scale of this gather/scatter instruction
97   int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
98   // If the value is a constant, or derived from constants via additions,
99   // multiplications, or shifts, return its numeric value
100   std::optional<int64_t> getIfConst(const Value *V);
101   // If Inst is an add instruction, check whether one summand is a
102   // constant. If so, scale this constant and return it together with
103   // the other summand.
104   std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
105 
106   Instruction *lowerGather(IntrinsicInst *I);
107   // Create a gather from a base + vector of offsets
108   Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
109                                            Instruction *&Root,
110                                            IRBuilder<> &Builder);
111   // Create a gather from a vector of pointers
112   Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
113                                          IRBuilder<> &Builder,
114                                          int64_t Increment = 0);
115   // Create an incrementing gather from a vector of pointers
116   Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
117                                            IRBuilder<> &Builder,
118                                            int64_t Increment = 0);
119 
120   Instruction *lowerScatter(IntrinsicInst *I);
121   // Create a scatter to a base + vector of offsets
122   Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
123                                             IRBuilder<> &Builder);
124   // Create a scatter to a vector of pointers
125   Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
126                                           IRBuilder<> &Builder,
127                                           int64_t Increment = 0);
128   // Create an incrementing scatter from a vector of pointers
129   Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
130                                             IRBuilder<> &Builder,
131                                             int64_t Increment = 0);
132 
133   // QI gathers and scatters can increment their offsets on their own if
134   // the increment is a constant value (an immediate)
135   Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
136                                             IRBuilder<> &Builder);
137   // QI gathers/scatters can increment their offsets on their own if the
138   // increment is a constant value (an immediate) - this creates a writeback QI
139   // gather/scatter
140   Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
141                                               Value *Ptr, unsigned TypeScale,
142                                               IRBuilder<> &Builder);
143 
144   // Optimise the base and offsets of the given address
145   bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
146   // Try to fold consecutive geps together into one
147   Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
148                  IRBuilder<> &Builder);
149   // Check whether these offsets could be moved out of the loop they're in
150   bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
151   // Pushes the given add out of the loop
152   void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
153   // Pushes the given mul or shl out of the loop
154   void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
155                      Value *OffsSecondOperand, unsigned LoopIncrement,
156                      IRBuilder<> &Builder);
157 };
158 
159 } // end anonymous namespace
160 
161 char MVEGatherScatterLowering::ID = 0;
162 
163 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
164                 "MVE gather/scatter lowering pass", false, false)
165 
166 Pass *llvm::createMVEGatherScatterLoweringPass() {
167   return new MVEGatherScatterLowering();
168 }
169 
170 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
171                                                        unsigned ElemSize,
172                                                        Align Alignment) {
173   if (((NumElements == 4 &&
174         (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
175        (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
176        (NumElements == 16 && ElemSize == 8)) &&
177       Alignment >= ElemSize / 8)
178     return true;
179   LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
180                     << "valid alignment or vector type \n");
181   return false;
182 }
183 
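// For instance, for a gather with 8 lanes (TargetElemCount == 8) the target
// element size is 128 / 8 == 16 bits, so any constant offsets must lie in
// [0, 1 << 16) to be usable as unsigned 16-bit offsets.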
184 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
185   // Offsets that are not of type <N x i32> are sign extended by the
186   // getelementptr instruction, and MVE gathers/scatters treat the offset as
187   // unsigned. Thus, if the element size is smaller than 32, we can only allow
188   // positive offsets - i.e., the offsets are not allowed to be variables we
189   // can't look into.
190   // Additionally, <N x i32> offsets have to either originate from a zext of a
191   // vector with element types smaller than or equal to the type of the gather
192   // we're looking at, or consist of constants that we can check are small
193   // enough to fit into the gather type.
194   // Thus we check that 0 <= value < 2^TargetElemSize.
195   unsigned TargetElemSize = 128 / TargetElemCount;
196   unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
197                                 ->getElementType()
198                                 ->getScalarSizeInBits();
199   if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
200     Constant *ConstOff = dyn_cast<Constant>(Offsets);
201     if (!ConstOff)
202       return false;
203     int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
204     auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
205       ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
206       if (!OConst)
207         return false;
208       int SExtValue = OConst->getSExtValue();
209       if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
210         return false;
211       return true;
212     };
213     if (isa<FixedVectorType>(ConstOff->getType())) {
214       for (unsigned i = 0; i < TargetElemCount; i++) {
215         if (!CheckValueSize(ConstOff->getAggregateElement(i)))
216           return false;
217       }
218     } else {
219       if (!CheckValueSize(ConstOff))
220         return false;
221     }
222   }
223   return true;
224 }
225 
226 Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
227                                               int &Scale, FixedVectorType *Ty,
228                                               Type *MemoryTy,
229                                               IRBuilder<> &Builder) {
230   if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
231     if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
232       Scale =
233           computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
234                        MemoryTy->getScalarSizeInBits());
235       return Scale == -1 ? nullptr : V;
236     }
237   }
238 
239   // If we couldn't use the GEP (or it doesn't exist), attempt to use a
240   // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
241   // elements.
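  // (The vector of pointers is reinterpreted as 32-bit offsets from a null
  // base pointer with a scale of 0, which only works for <4 x ptr> operands
  // on this 32-bit target.)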
242   FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
243   if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
244     return nullptr;
245   Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
246   Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getPtrTy());
247   Offsets = Builder.CreatePtrToInt(
248       Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
249   Scale = 0;
250   return BasePtr;
251 }
252 
253 Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
254                                               FixedVectorType *Ty,
255                                               GetElementPtrInst *GEP,
256                                               IRBuilder<> &Builder) {
257   if (!GEP) {
258     LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
259                       << "found\n");
260     return nullptr;
261   }
262   LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
263                     << " Looking at intrinsic for base + vector of offsets\n");
264   Value *GEPPtr = GEP->getPointerOperand();
265   Offsets = GEP->getOperand(1);
266   if (GEPPtr->getType()->isVectorTy() ||
267       !isa<FixedVectorType>(Offsets->getType()))
268     return nullptr;
269 
270   if (GEP->getNumOperands() != 2) {
271     LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
272                       << " operands. Expanding.\n");
273     return nullptr;
274   }
275   Offsets = GEP->getOperand(1);
276   unsigned OffsetsElemCount =
277       cast<FixedVectorType>(Offsets->getType())->getNumElements();
278   // Paranoid check whether the number of parallel lanes is the same
279   assert(Ty->getNumElements() == OffsetsElemCount);
280 
281   ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
282   if (ZextOffs)
283     Offsets = ZextOffs->getOperand(0);
284   FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
285 
286   // If the offsets are already being zext-ed to <N x i32>, that relieves us of
287   // having to make sure that they won't overflow.
288   if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
289                            ->getElementType()
290                            ->getScalarSizeInBits() != 32)
291     if (!checkOffsetSize(Offsets, OffsetsElemCount))
292       return nullptr;
293 
294   // The offset sizes have been checked; if any truncating or zext-ing is
295   // required to fix them, do that now
296   if (Ty != Offsets->getType()) {
297     if ((Ty->getElementType()->getScalarSizeInBits() <
298          OffsetType->getElementType()->getScalarSizeInBits())) {
299       Offsets = Builder.CreateTrunc(Offsets, Ty);
300     } else {
301       Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
302     }
303   }
304   // If none of the checks failed, return the gep's base pointer
305   LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
306   return GEPPtr;
307 }
308 
309 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
310   // Look through bitcast instruction if #elements is the same
311   if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
312     auto *BCTy = cast<FixedVectorType>(BitCast->getType());
313     auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
314     if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
315       LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
316                         << "bitcast\n");
317       Ptr = BitCast->getOperand(0);
318     }
319   }
320 }
321 
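// For example, a 32-bit gather through a GEP over i32 elements has
// GEPElemSize == MemoryElemSize == 32 and gets a scale of 2 (offsets are in
// units of 4 bytes), whereas a GEP over i8 elements always gets a scale of 0.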
322 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
323                                            unsigned MemoryElemSize) {
324   // This can be a 32-bit load/store scaled by 4, a 16-bit load/store scaled by
325   // 2, or an 8-bit, 16-bit or 32-bit load/store scaled by 1
326   if (GEPElemSize == 32 && MemoryElemSize == 32)
327     return 2;
328   else if (GEPElemSize == 16 && MemoryElemSize == 16)
329     return 1;
330   else if (GEPElemSize == 8)
331     return 0;
332   LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
333                     << "create intrinsic\n");
334   return -1;
335 }
336 
337 std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
338   const Constant *C = dyn_cast<Constant>(V);
339   if (C && C->getSplatValue())
340     return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()};
341   if (!isa<Instruction>(V))
342     return std::optional<int64_t>{};
343 
344   const Instruction *I = cast<Instruction>(V);
345   if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
346       I->getOpcode() == Instruction::Mul ||
347       I->getOpcode() == Instruction::Shl) {
348     std::optional<int64_t> Op0 = getIfConst(I->getOperand(0));
349     std::optional<int64_t> Op1 = getIfConst(I->getOperand(1));
350     if (!Op0 || !Op1)
351       return std::optional<int64_t>{};
352     if (I->getOpcode() == Instruction::Add)
353       return std::optional<int64_t>{*Op0 + *Op1};
354     if (I->getOpcode() == Instruction::Mul)
355       return std::optional<int64_t>{*Op0 * *Op1};
356     if (I->getOpcode() == Instruction::Shl)
357       return std::optional<int64_t>{*Op0 << *Op1};
358     if (I->getOpcode() == Instruction::Or)
359       return std::optional<int64_t>{*Op0 | *Op1};
360   }
361   return std::optional<int64_t>{};
362 }
363 
364 // Return true if I is an Or instruction that is equivalent to an add, due to
365 // the operands having no common bits set.
366 static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
367   return I->getOpcode() == Instruction::Or &&
368          haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
369 }
370 
371 std::pair<Value *, int64_t>
372 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
373   std::pair<Value *, int64_t> ReturnFalse =
374       std::pair<Value *, int64_t>(nullptr, 0);
375   // At this point, the instruction we're looking at must be an add or an
376   // add-like-or.
377   Instruction *Add = dyn_cast<Instruction>(Inst);
378   if (Add == nullptr ||
379       (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
380     return ReturnFalse;
381 
382   Value *Summand;
383   std::optional<int64_t> Const;
384   // Find out which operand is the constant; the other is the variable summand
385   if ((Const = getIfConst(Add->getOperand(0))))
386     Summand = Add->getOperand(1);
387   else if ((Const = getIfConst(Add->getOperand(1))))
388     Summand = Add->getOperand(0);
389   else
390     return ReturnFalse;
391 
392   // Check that the constant is small enough for an incrementing gather
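  // (i.e. after scaling by the element size it must be a multiple of 4 in the
  // range [-512, 512], so it can be used as the instruction's immediate
  // increment)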
393   int64_t Immediate = *Const << TypeScale;
394   if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
395     return ReturnFalse;
396 
397   return std::pair<Value *, int64_t>(Summand, Immediate);
398 }
399 
400 Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
401   using namespace PatternMatch;
402   LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
403                     << *I << "\n");
404 
405   // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
406   // Attempt to turn the masked gather in I into an MVE intrinsic,
407   // potentially optimising the addressing modes as we do so.
408   auto *Ty = cast<FixedVectorType>(I->getType());
409   Value *Ptr = I->getArgOperand(0);
410   Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
411   Value *Mask = I->getArgOperand(2);
412   Value *PassThru = I->getArgOperand(3);
413 
414   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
415                                Alignment))
416     return nullptr;
417   lookThroughBitcast(Ptr);
418   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
419 
420   IRBuilder<> Builder(I->getContext());
421   Builder.SetInsertPoint(I);
422   Builder.SetCurrentDebugLocation(I->getDebugLoc());
423 
424   Instruction *Root = I;
425 
426   Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
427   if (!Load)
428     Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
429   if (!Load)
430     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
431   if (!Load)
432     return nullptr;
433 
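  // Only a passthru that is neither undef nor zero needs to be honoured
  // explicitly, which is done with a select on the original mask below.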
434   if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
435     LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
436                       << "creating select\n");
437     Load = SelectInst::Create(Mask, Load, PassThru);
438     Builder.Insert(Load);
439   }
440 
441   Root->replaceAllUsesWith(Load);
442   Root->eraseFromParent();
443   if (Root != I)
444     // If this was an extending gather, we need to get rid of the sext/zext
445     // as well as of the gather itself
446     I->eraseFromParent();
447 
448   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
449                     << *Load << "\n");
450   return Load;
451 }
452 
453 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
454     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
455   using namespace PatternMatch;
456   auto *Ty = cast<FixedVectorType>(I->getType());
457   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
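  // The base-addressed form of the gather only exists for four 32-bit lanes
  // (a v4i32 result loaded through a vector of four addresses), hence the
  // check below.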
458   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
459     // Can't build an intrinsic for this
460     return nullptr;
461   Value *Mask = I->getArgOperand(2);
462   if (match(Mask, m_One()))
463     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
464                                    {Ty, Ptr->getType()},
465                                    {Ptr, Builder.getInt32(Increment)});
466   else
467     return Builder.CreateIntrinsic(
468         Intrinsic::arm_mve_vldr_gather_base_predicated,
469         {Ty, Ptr->getType(), Mask->getType()},
470         {Ptr, Builder.getInt32(Increment), Mask});
471 }
472 
473 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
474     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
475   using namespace PatternMatch;
476   auto *Ty = cast<FixedVectorType>(I->getType());
477   LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
478                     << "writeback\n");
479   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
480     // Can't build an intrinsic for this
481     return nullptr;
482   Value *Mask = I->getArgOperand(2);
483   if (match(Mask, m_One()))
484     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
485                                    {Ty, Ptr->getType()},
486                                    {Ptr, Builder.getInt32(Increment)});
487   else
488     return Builder.CreateIntrinsic(
489         Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
490         {Ty, Ptr->getType(), Mask->getType()},
491         {Ptr, Builder.getInt32(Increment), Mask});
492 }
493 
494 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
495     IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
496   using namespace PatternMatch;
497 
498   Type *MemoryTy = I->getType();
499   Type *ResultTy = MemoryTy;
500 
501   unsigned Unsigned = 1;
502   // The size of the gather was already checked in isLegalTypeAndAlignment;
503   // if it was not a full vector width, an appropriate extend should follow.
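  // (For instance, a <4 x i16> gather whose only user is a sext to <4 x i32>
  // can become a single extending gather; the sext then becomes the root that
  // gets replaced instead of the gather itself.)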
504   auto *Extend = Root;
505   bool TruncResult = false;
506   if (MemoryTy->getPrimitiveSizeInBits() < 128) {
507     if (I->hasOneUse()) {
508       // If the gather has a single extend of the correct type, use an extending
509       // gather and replace the ext, in which case the correct root to replace
510       // is not the CallInst itself, but the instruction which extends it.
511       Instruction *User = cast<Instruction>(*I->users().begin());
512       if (isa<SExtInst>(User) &&
513           User->getType()->getPrimitiveSizeInBits() == 128) {
514         LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
515                           << *User << "\n");
516         Extend = User;
517         ResultTy = User->getType();
518         Unsigned = 0;
519       } else if (isa<ZExtInst>(User) &&
520                  User->getType()->getPrimitiveSizeInBits() == 128) {
521         LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
522                           << *User << "\n");
523         Extend = User;
524         ResultTy = User->getType();
525       }
526     }
527 
528     // If an extend hasn't been found and the type is an integer, create an
529     // extending gather and truncate back to the original type.
530     if (ResultTy->getPrimitiveSizeInBits() < 128 &&
531         ResultTy->isIntOrIntVectorTy()) {
532       ResultTy = ResultTy->getWithNewBitWidth(
533           128 / cast<FixedVectorType>(ResultTy)->getNumElements());
534       TruncResult = true;
535       LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncating to: "
536                         << *ResultTy << "\n");
537     }
538 
539     // The final size of the gather must be a full vector width
540     if (ResultTy->getPrimitiveSizeInBits() != 128) {
541       LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
542                            "from the correct type. Expanding\n");
543       return nullptr;
544     }
545   }
546 
547   Value *Offsets;
548   int Scale;
549   Value *BasePtr = decomposePtr(
550       Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
551   if (!BasePtr)
552     return nullptr;
553 
554   Root = Extend;
555   Value *Mask = I->getArgOperand(2);
556   Instruction *Load = nullptr;
557   if (!match(Mask, m_One()))
558     Load = Builder.CreateIntrinsic(
559         Intrinsic::arm_mve_vldr_gather_offset_predicated,
560         {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
561         {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
562          Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
563   else
564     Load = Builder.CreateIntrinsic(
565         Intrinsic::arm_mve_vldr_gather_offset,
566         {ResultTy, BasePtr->getType(), Offsets->getType()},
567         {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
568          Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
569 
570   if (TruncResult) {
571     Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
572     Builder.Insert(Load);
573   }
574   return Load;
575 }
576 
577 Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
578   using namespace PatternMatch;
579   LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
580                     << *I << "\n");
581 
582   // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
583   // Attempt to turn the masked scatter in I into an MVE intrinsic,
584   // potentially optimising the addressing modes as we do so.
585   Value *Input = I->getArgOperand(0);
586   Value *Ptr = I->getArgOperand(1);
587   Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
588   auto *Ty = cast<FixedVectorType>(Input->getType());
589 
590   if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
591                                Alignment))
592     return nullptr;
593 
594   lookThroughBitcast(Ptr);
595   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
596 
597   IRBuilder<> Builder(I->getContext());
598   Builder.SetInsertPoint(I);
599   Builder.SetCurrentDebugLocation(I->getDebugLoc());
600 
601   Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
602   if (!Store)
603     Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
604   if (!Store)
605     Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
606   if (!Store)
607     return nullptr;
608 
609   LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
610                     << *Store << "\n");
611   I->eraseFromParent();
612   return Store;
613 }
614 
615 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
616     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
617   using namespace PatternMatch;
618   Value *Input = I->getArgOperand(0);
619   auto *Ty = cast<FixedVectorType>(Input->getType());
620   // Only QR variants allow truncating
621   if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
622     // Can't build an intrinsic for this
623     return nullptr;
624   }
625   Value *Mask = I->getArgOperand(3);
626   //  int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
627   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
628   if (match(Mask, m_One()))
629     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
630                                    {Ptr->getType(), Input->getType()},
631                                    {Ptr, Builder.getInt32(Increment), Input});
632   else
633     return Builder.CreateIntrinsic(
634         Intrinsic::arm_mve_vstr_scatter_base_predicated,
635         {Ptr->getType(), Input->getType(), Mask->getType()},
636         {Ptr, Builder.getInt32(Increment), Input, Mask});
637 }
638 
639 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
640     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
641   using namespace PatternMatch;
642   Value *Input = I->getArgOperand(0);
643   auto *Ty = cast<FixedVectorType>(Input->getType());
644   LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
645                     << "with writeback\n");
646   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
647     // Can't build an intrinsic for this
648     return nullptr;
649   Value *Mask = I->getArgOperand(3);
650   if (match(Mask, m_One()))
651     return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
652                                    {Ptr->getType(), Input->getType()},
653                                    {Ptr, Builder.getInt32(Increment), Input});
654   else
655     return Builder.CreateIntrinsic(
656         Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
657         {Ptr->getType(), Input->getType(), Mask->getType()},
658         {Ptr, Builder.getInt32(Increment), Input, Mask});
659 }
660 
661 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
662     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
663   using namespace PatternMatch;
664   Value *Input = I->getArgOperand(0);
665   Value *Mask = I->getArgOperand(3);
666   Type *InputTy = Input->getType();
667   Type *MemoryTy = InputTy;
668 
669   LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
670                     << " to base + vector of offsets\n");
671   // If the input has been truncated, try to integrate that trunc into the
672   // scatter instruction (we don't care about alignment here)
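  // (e.g. storing a trunc of a <4 x i32> value down to <4 x i8>: the scatter
  // can take the untruncated v4i32 register and narrow while storing)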
673   if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
674     Value *PreTrunc = Trunc->getOperand(0);
675     Type *PreTruncTy = PreTrunc->getType();
676     if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
677       Input = PreTrunc;
678       InputTy = PreTruncTy;
679     }
680   }
681   bool ExtendInput = false;
682   if (InputTy->getPrimitiveSizeInBits() < 128 &&
683       InputTy->isIntOrIntVectorTy()) {
684     // If we can't find a trunc to incorporate into the instruction, create an
685     // implicit one with a zext, so that we can still create a scatter. We know
686     // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
687     // smaller than 128 bits will divide evenly into a 128bit vector.
688     InputTy = InputTy->getWithNewBitWidth(
689         128 / cast<FixedVectorType>(InputTy)->getNumElements());
690     ExtendInput = true;
691     LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
692                       << *Input << "\n");
693   }
694   if (InputTy->getPrimitiveSizeInBits() != 128) {
695     LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
696                          "non-standard input types. Expanding.\n");
697     return nullptr;
698   }
699 
700   Value *Offsets;
701   int Scale;
702   Value *BasePtr = decomposePtr(
703       Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
704   if (!BasePtr)
705     return nullptr;
706 
707   if (ExtendInput)
708     Input = Builder.CreateZExt(Input, InputTy);
709   if (!match(Mask, m_One()))
710     return Builder.CreateIntrinsic(
711         Intrinsic::arm_mve_vstr_scatter_offset_predicated,
712         {BasePtr->getType(), Offsets->getType(), Input->getType(),
713          Mask->getType()},
714         {BasePtr, Offsets, Input,
715          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
716          Builder.getInt32(Scale), Mask});
717   else
718     return Builder.CreateIntrinsic(
719         Intrinsic::arm_mve_vstr_scatter_offset,
720         {BasePtr->getType(), Offsets->getType(), Input->getType()},
721         {BasePtr, Offsets, Input,
722          Builder.getInt32(MemoryTy->getScalarSizeInBits()),
723          Builder.getInt32(Scale)});
724 }
725 
726 Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
727     IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
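  // An "incrementing" gather/scatter folds a loop induction pattern where the
  // offsets are (phi + constant) into the QI form of the instruction: the
  // remaining variable offsets, scaled and added to the base, become the
  // vector base register and the constant part becomes the immediate
  // increment (optionally written back to the induction phi).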
728   FixedVectorType *Ty;
729   if (I->getIntrinsicID() == Intrinsic::masked_gather)
730     Ty = cast<FixedVectorType>(I->getType());
731   else
732     Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
733 
734   // Incrementing gathers only exist for v4i32
735   if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
736     return nullptr;
737   // Incrementing gathers are not beneficial outside of a loop
738   Loop *L = LI->getLoopFor(I->getParent());
739   if (L == nullptr)
740     return nullptr;
741 
742   // Decompose the GEP into Base and Offsets
743   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
744   Value *Offsets;
745   Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
746   if (!BasePtr)
747     return nullptr;
748 
749   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
750                        "wb gather/scatter\n");
751 
752   // The gep was in charge of making sure the offsets are scaled correctly
753   // - calculate that factor so it can be applied by hand
754   int TypeScale =
755       computeScale(DL->getTypeSizeInBits(GEP->getSourceElementType()),
756                    DL->getTypeSizeInBits(GEP->getType()) /
757                        cast<FixedVectorType>(GEP->getType())->getNumElements());
758   if (TypeScale == -1)
759     return nullptr;
760 
761   if (GEP->hasOneUse()) {
762     // Only in this case do we want to build a wb gather, because the wb will
763     // change the phi, which would affect other users of the gep (which will still
764     // be using the phi in the old way)
765     if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
766                                                     TypeScale, Builder))
767       return Load;
768   }
769 
770   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
771                        "non-wb gather/scatter\n");
772 
773   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
774   if (Add.first == nullptr)
775     return nullptr;
776   Value *OffsetsIncoming = Add.first;
777   int64_t Immediate = Add.second;
778 
779   // Make sure the offsets are scaled correctly
780   Instruction *ScaledOffsets = BinaryOperator::Create(
781       Instruction::Shl, OffsetsIncoming,
782       Builder.CreateVectorSplat(Ty->getNumElements(),
783                                 Builder.getInt32(TypeScale)),
784       "ScaledIndex", I->getIterator());
785   // Add the base to the offsets
786   OffsetsIncoming = BinaryOperator::Create(
787       Instruction::Add, ScaledOffsets,
788       Builder.CreateVectorSplat(
789           Ty->getNumElements(),
790           Builder.CreatePtrToInt(
791               BasePtr,
792               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
793       "StartIndex", I->getIterator());
794 
795   if (I->getIntrinsicID() == Intrinsic::masked_gather)
796     return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
797   else
798     return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
799 }
800 
801 Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
802     IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
803     IRBuilder<> &Builder) {
804   // Check whether this gather's offset is incremented by a constant - if so,
805   // and the load is of the right type, we can merge this into a QI gather
806   Loop *L = LI->getLoopFor(I->getParent());
807   // Offsets that are worth merging into this instruction will be incremented
808   // by a constant, thus we're looking for an add of a phi and a constant
809   PHINode *Phi = dyn_cast<PHINode>(Offsets);
810   if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
811       Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
812     // No phi means no IV to write back to; if there is a phi, we expect it
813     // to have exactly two incoming values; the only phis we are interested in
814     // will be loop IVs and have exactly two uses, one in their increment and
815     // one in the gather's gep
816     return nullptr;
817 
818   unsigned IncrementIndex =
819       Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
820   // Look through the phi to the phi increment
821   Offsets = Phi->getIncomingValue(IncrementIndex);
822 
823   std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
824   if (Add.first == nullptr)
825     return nullptr;
826   Value *OffsetsIncoming = Add.first;
827   int64_t Immediate = Add.second;
828   if (OffsetsIncoming != Phi)
829     // Then the increment we are looking at is not an increment of the
830     // induction variable, and we don't want to do a writeback
831     return nullptr;
832 
833   Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
834   unsigned NumElems =
835       cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
836 
837   // Make sure the offsets are scaled correctly
838   Instruction *ScaledOffsets = BinaryOperator::Create(
839       Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
840       Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
841       "ScaledIndex",
842       Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
843   // Add the base to the offsets
844   OffsetsIncoming = BinaryOperator::Create(
845       Instruction::Add, ScaledOffsets,
846       Builder.CreateVectorSplat(
847           NumElems,
848           Builder.CreatePtrToInt(
849               BasePtr,
850               cast<VectorType>(ScaledOffsets->getType())->getElementType())),
851       "StartIndex",
852       Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
853   // The gather is pre-incrementing
854   OffsetsIncoming = BinaryOperator::Create(
855       Instruction::Sub, OffsetsIncoming,
856       Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
857       "PreIncrementStartIndex",
858       Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator());
859   Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
860 
861   Builder.SetInsertPoint(I);
862 
863   Instruction *EndResult;
864   Instruction *NewInduction;
865   if (I->getIntrinsicID() == Intrinsic::masked_gather) {
866     // Build the incrementing gather
867     Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
868     // One value is handed to whoever uses the gather, the other is the loop
869     // increment
870     EndResult = ExtractValueInst::Create(Load, 0, "Gather");
871     NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
872     Builder.Insert(EndResult);
873     Builder.Insert(NewInduction);
874   } else {
875     // Build the incrementing scatter
876     EndResult = NewInduction =
877         tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
878   }
879   Instruction *AddInst = cast<Instruction>(Offsets);
880   AddInst->replaceAllUsesWith(NewInduction);
881   AddInst->eraseFromParent();
882   Phi->setIncomingValue(IncrementIndex, NewInduction);
883 
884   return EndResult;
885 }
886 
887 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
888                                           Value *OffsSecondOperand,
889                                           unsigned StartIndex) {
890   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
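  // In effect, an offset of the form (phi + C) inside the loop is replaced by
  // starting the phi at (start + C), so the add is computed once outside the
  // loop rather than on every iteration.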
891   BasicBlock::iterator InsertionPoint =
892       Phi->getIncomingBlock(StartIndex)->back().getIterator();
893   // Fold the invariant operand into the phi's start value, outside the loop
894   Instruction *NewIndex = BinaryOperator::Create(
895       Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
896       "PushedOutAdd", InsertionPoint);
897   unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
898 
899   // Order such that start index comes first (this reduces mov's)
900   Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
901   Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
902                    Phi->getIncomingBlock(IncrementIndex));
903   Phi->removeIncomingValue(1);
904   Phi->removeIncomingValue((unsigned)0);
905 }
906 
907 void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
908                                              Value *IncrementPerRound,
909                                              Value *OffsSecondOperand,
910                                              unsigned LoopIncrement,
911                                              IRBuilder<> &Builder) {
912   LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
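  // In effect, an offset of the form (phi * C) (or phi << C) is replaced by a
  // recurrence that starts at (start * C) and is advanced by the pre-computed
  // (IncrementPerRound * C) each round, hoisting the multiply out of the loop.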
913 
914   // Create the new start value and the per-round increment outside of the
915   // loop, so the loop variable only needs a simple vector add each round
916   BasicBlock::iterator InsertionPoint =
917       Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back().getIterator();
918 
919   // Create a new index
920   Value *StartIndex =
921       BinaryOperator::Create((Instruction::BinaryOps)Opcode,
922                              Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
923                              OffsSecondOperand, "PushedOutMul", InsertionPoint);
924 
925   Instruction *Product =
926       BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
927                              OffsSecondOperand, "Product", InsertionPoint);
928 
929   BasicBlock::iterator NewIncrInsertPt =
930       Phi->getIncomingBlock(LoopIncrement)->back().getIterator();
931   NewIncrInsertPt = std::prev(NewIncrInsertPt);
932 
933   // Increment the phi by the pre-computed Product instead of redoing the mul
934   Instruction *NewIncrement = BinaryOperator::Create(
935       Instruction::Add, Phi, Product, "IncrementPushedOutMul", NewIncrInsertPt);
936 
937   Phi->addIncoming(StartIndex,
938                    Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
939   Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
940   Phi->removeIncomingValue((unsigned)0);
941   Phi->removeIncomingValue((unsigned)0);
942 }
943 
944 // Check whether all usages of this instruction are as offsets of
945 // gathers/scatters or simple arithmetic only used by gathers/scatters
946 static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) {
947   if (I->hasNUses(0)) {
948     return false;
949   }
950   bool Gatscat = true;
951   for (User *U : I->users()) {
952     if (!isa<Instruction>(U))
953       return false;
954     if (isa<GetElementPtrInst>(U) ||
955         isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
956       return Gatscat;
957     } else {
958       unsigned OpCode = cast<Instruction>(U)->getOpcode();
959       if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
960            OpCode == Instruction::Shl ||
961            isAddLikeOr(cast<Instruction>(U), DL)) &&
962           hasAllGatScatUsers(cast<Instruction>(U), DL)) {
963         continue;
964       }
965       return false;
966     }
967   }
968   return Gatscat;
969 }
970 
971 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
972                                                LoopInfo *LI) {
973   LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
974                     << *Offsets << "\n");
975   // Optimise the addresses of gathers/scatters by moving invariant
976   // calculations out of the loop
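  // The pattern handled here is an offset of the form (phi op invariant),
  // where the phi is a simple add recurrence in the loop header; the op is
  // then folded into the recurrence itself by pushOutAdd / pushOutMulShl.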
977   if (!isa<Instruction>(Offsets))
978     return false;
979   Instruction *Offs = cast<Instruction>(Offsets);
980   if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
981       Offs->getOpcode() != Instruction::Mul &&
982       Offs->getOpcode() != Instruction::Shl)
983     return false;
984   Loop *L = LI->getLoopFor(BB);
985   if (L == nullptr)
986     return false;
987   if (!Offs->hasOneUse()) {
988     if (!hasAllGatScatUsers(Offs, *DL))
989       return false;
990   }
991 
992   // Find out which, if any, operand of the instruction
993   // is a phi node
994   PHINode *Phi;
995   int OffsSecondOp;
996   if (isa<PHINode>(Offs->getOperand(0))) {
997     Phi = cast<PHINode>(Offs->getOperand(0));
998     OffsSecondOp = 1;
999   } else if (isa<PHINode>(Offs->getOperand(1))) {
1000     Phi = cast<PHINode>(Offs->getOperand(1));
1001     OffsSecondOp = 0;
1002   } else {
1003     bool Changed = false;
1004     if (isa<Instruction>(Offs->getOperand(0)) &&
1005         L->contains(cast<Instruction>(Offs->getOperand(0))))
1006       Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
1007     if (isa<Instruction>(Offs->getOperand(1)) &&
1008         L->contains(cast<Instruction>(Offs->getOperand(1))))
1009       Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
1010     if (!Changed)
1011       return false;
1012     if (isa<PHINode>(Offs->getOperand(0))) {
1013       Phi = cast<PHINode>(Offs->getOperand(0));
1014       OffsSecondOp = 1;
1015     } else if (isa<PHINode>(Offs->getOperand(1))) {
1016       Phi = cast<PHINode>(Offs->getOperand(1));
1017       OffsSecondOp = 0;
1018     } else {
1019       return false;
1020     }
1021   }
1022   // A phi node we want to perform this optimisation on should come from the
1023   // loop header.
1024   if (Phi->getParent() != L->getHeader())
1025     return false;
1026 
1027   // We're looking for a simple add recurrence.
1028   BinaryOperator *IncInstruction;
1029   Value *Start, *IncrementPerRound;
1030   if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
1031       IncInstruction->getOpcode() != Instruction::Add)
1032     return false;
1033 
1034   int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
1035 
1036   // Get the value that is added to/multiplied with the phi
1037   Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
1038 
1039   if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
1040       !L->isLoopInvariant(OffsSecondOperand))
1041     // Something has gone wrong, abort
1042     return false;
1043 
1044   // Only proceed if the increment per round is a constant or an instruction
1045   // which does not originate from within the loop
1046   if (!isa<Constant>(IncrementPerRound) &&
1047       !(isa<Instruction>(IncrementPerRound) &&
1048         !L->contains(cast<Instruction>(IncrementPerRound))))
1049     return false;
1050 
1051   // If the phi is not used by anything else, we can just adapt it when
1052   // replacing the instruction; if it is, we'll have to duplicate it
1053   PHINode *NewPhi;
1054   if (Phi->getNumUses() == 2) {
1055     // No other users -> reuse existing phi (One user is the instruction
1056     // we're looking at, the other is the phi increment)
1057     if (IncInstruction->getNumUses() != 1) {
1058       // If the incrementing instruction does have more users than
1059       // our phi, we need to copy it
1060       IncInstruction = BinaryOperator::Create(
1061           Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
1062           IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
1063       Phi->setIncomingValue(IncrementingBlock, IncInstruction);
1064     }
1065     NewPhi = Phi;
1066   } else {
1067     // There are other users -> create a new phi
1068     NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi->getIterator());
1069     // Copy the incoming values of the old phi
1070     NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
1071                         Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
1072     IncInstruction = BinaryOperator::Create(
1073         Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
1074         IncrementPerRound, "LoopIncrement", IncInstruction->getIterator());
1075     NewPhi->addIncoming(IncInstruction,
1076                         Phi->getIncomingBlock(IncrementingBlock));
1077     IncrementingBlock = 1;
1078   }
1079 
1080   IRBuilder<> Builder(BB->getContext());
1081   Builder.SetInsertPoint(Phi);
1082   Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
1083 
1084   switch (Offs->getOpcode()) {
1085   case Instruction::Add:
1086   case Instruction::Or:
1087     pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
1088     break;
1089   case Instruction::Mul:
1090   case Instruction::Shl:
1091     pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
1092                   OffsSecondOperand, IncrementingBlock, Builder);
1093     break;
1094   default:
1095     return false;
1096   }
1097   LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
1098                     << "add/mul\n");
1099 
1100   // The instruction has now been "absorbed" into the phi value
1101   Offs->replaceAllUsesWith(NewPhi);
1102   if (Offs->hasNUses(0))
1103     Offs->eraseFromParent();
1104   // Clean up the old increment in case it's unused because we built a new
1105   // one
1106   if (IncInstruction->hasNUses(0))
1107     IncInstruction->eraseFromParent();
1108 
1109   return true;
1110 }
1111 
1112 static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y,
1113                                       unsigned ScaleY, IRBuilder<> &Builder) {
1114   // Splat the non-vector value to a vector of the given type - if the value is
1115   // a constant (and its value isn't too big), we can even use this opportunity
1116   // to scale it to the size of the vector elements
1117   auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
1118     ConstantInt *Const;
1119     if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
1120         VT->getElementType() != NonVectorVal->getType()) {
1121       unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
1122       uint64_t N = Const->getZExtValue();
1123       if (N < (unsigned)(1 << (TargetElemSize - 1))) {
1124         NonVectorVal = Builder.CreateVectorSplat(
1125             VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
1126         return;
1127       }
1128     }
1129     NonVectorVal =
1130         Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
1131   };
1132 
1133   FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
1134   FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
1135   // If one of X, Y is not a vector, we have to splat it in order
1136   // to add the two of them.
1137   if (XElType && !YElType) {
1138     FixSummands(XElType, Y);
1139     YElType = cast<FixedVectorType>(Y->getType());
1140   } else if (YElType && !XElType) {
1141     FixSummands(YElType, X);
1142     XElType = cast<FixedVectorType>(X->getType());
1143   }
1144   assert(XElType && YElType && "Unknown vector types");
1145   // Check that the summands are of compatible types
1146   if (XElType != YElType) {
1147     LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
1148     return nullptr;
1149   }
1150 
1151   if (XElType->getElementType()->getScalarSizeInBits() != 32) {
1152     // Check that by adding the vectors we do not accidentally
1153     // create an overflow
1154     Constant *ConstX = dyn_cast<Constant>(X);
1155     Constant *ConstY = dyn_cast<Constant>(Y);
1156     if (!ConstX || !ConstY)
1157       return nullptr;
1158     unsigned TargetElemSize = 128 / XElType->getNumElements();
1159     for (unsigned i = 0; i < XElType->getNumElements(); i++) {
1160       ConstantInt *ConstXEl =
1161           dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
1162       ConstantInt *ConstYEl =
1163           dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
1164       if (!ConstXEl || !ConstYEl ||
1165           ConstXEl->getZExtValue() * ScaleX +
1166                   ConstYEl->getZExtValue() * ScaleY >=
1167               (unsigned)(1 << (TargetElemSize - 1)))
1168         return nullptr;
1169     }
1170   }
1171 
1172   Value *XScale = Builder.CreateVectorSplat(
1173       XElType->getNumElements(),
1174       Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX));
1175   Value *YScale = Builder.CreateVectorSplat(
1176       YElType->getNumElements(),
1177       Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY));
1178   Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
1179                                  Builder.CreateMul(Y, YScale));
1180 
1181   if (checkOffsetSize(Add, XElType->getNumElements()))
1182     return Add;
1183   else
1184     return nullptr;
1185 }
1186 
1187 Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
1188                                          Value *&Offsets, unsigned &Scale,
1189                                          IRBuilder<> &Builder) {
1190   Value *GEPPtr = GEP->getPointerOperand();
1191   Offsets = GEP->getOperand(1);
1192   Scale = DL->getTypeAllocSize(GEP->getSourceElementType());
1193   // We only merge geps with constant offsets, because only for those
1194   // we can make sure that we do not cause an overflow
1195   if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets))
1196     return nullptr;
1197   if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) {
1198     // Merge the two geps into one
1199     Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder);
1200     if (!BaseBasePtr)
1201       return nullptr;
1202     Offsets = CheckAndCreateOffsetAdd(
1203         Offsets, Scale, GEP->getOperand(1),
1204         DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
1205     if (Offsets == nullptr)
1206       return nullptr;
1207     Scale = 1; // The merged offsets are byte (i8) offsets, so the scale is 1.
1208     return BaseBasePtr;
1209   }
1210   return GEPPtr;
1211 }
1212 
1213 bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
1214                                                LoopInfo *LI) {
1215   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
1216   if (!GEP)
1217     return false;
1218   bool Changed = false;
1219   if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
1220     IRBuilder<> Builder(GEP->getContext());
1221     Builder.SetInsertPoint(GEP);
1222     Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
1223     Value *Offsets;
1224     unsigned Scale;
1225     Value *Base = foldGEP(GEP, Offsets, Scale, Builder);
1226     // We only want to merge the geps if there is a real chance that they can be
1227     // used by an MVE gather; thus the offset has to have the correct size
1228     // (always i32 if it is not of vector type) and the base has to be a
1229     // pointer.
1230     if (Offsets && Base && Base != GEP) {
1231       assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
1232       Type *BaseTy = Builder.getPtrTy();
1233       if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType()))
1234         BaseTy = FixedVectorType::get(BaseTy, VecTy);
1235       GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
1236           Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets,
1237           "gep.merged", GEP->getIterator());
1238       LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP
1239                         << "\n      new :  " << *NewAddress << "\n");
1240       GEP->replaceAllUsesWith(
1241           Builder.CreateBitCast(NewAddress, GEP->getType()));
1242       GEP = NewAddress;
1243       Changed = true;
1244     }
1245   }
1246   Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
1247   return Changed;
1248 }
1249 
1250 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
1251   if (!EnableMaskedGatherScatters)
1252     return false;
1253   auto &TPC = getAnalysis<TargetPassConfig>();
1254   auto &TM = TPC.getTM<TargetMachine>();
1255   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1256   if (!ST->hasMVEIntegerOps())
1257     return false;
1258   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1259   DL = &F.getDataLayout();
1260   SmallVector<IntrinsicInst *, 4> Gathers;
1261   SmallVector<IntrinsicInst *, 4> Scatters;
1262 
1263   bool Changed = false;
1264 
1265   for (BasicBlock &BB : F) {
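  // First walk the function: simplify each block and optimise the addresses of
  // any masked gathers/scatters found, collecting the intrinsics so that the
  // actual lowering can happen in a second phase below.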
1266     Changed |= SimplifyInstructionsInBlock(&BB);
1267 
1268     for (Instruction &I : BB) {
1269       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
1270       if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
1271           isa<FixedVectorType>(II->getType())) {
1272         Gathers.push_back(II);
1273         Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
1274       } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
1275                  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
1276         Scatters.push_back(II);
1277         Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
1278       }
1279     }
1280   }
1281   for (IntrinsicInst *I : Gathers) {
1282     Instruction *L = lowerGather(I);
1283     if (L == nullptr)
1284       continue;
1285 
1286     // Get rid of any now dead instructions
1287     SimplifyInstructionsInBlock(L->getParent());
1288     Changed = true;
1289   }
1290 
1291   for (IntrinsicInst *I : Scatters) {
1292     Instruction *S = lowerScatter(I);
1293     if (S == nullptr)
1294       continue;
1295 
1296     // Get rid of any now dead instructions
1297     SimplifyInstructionsInBlock(S->getParent());
1298     Changed = true;
1299   }
1300   return Changed;
1301 }
1302