xref: /openbsd-src/gnu/llvm/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
1 //===- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass merges loads/stores to/from sequential memory addresses into vector
10 // loads/stores.  Although there's nothing GPU-specific in here, this pass is
11 // motivated by the microarchitectural quirks of nVidia and AMD GPUs.
12 //
13 // (For simplicity below we talk about loads only, but everything also applies
14 // to stores.)
15 //
16 // This pass is intended to be run late in the pipeline, after other
17 // vectorization opportunities have been exploited.  So the assumption here is
18 // that immediately following our new vector load we'll need to extract out the
19 // individual elements of the load, so we can operate on them individually.
20 //
21 // On CPUs this transformation is usually not beneficial, because extracting the
22 // elements of a vector register is expensive on most architectures.  It's
23 // usually better just to load each element individually into its own scalar
24 // register.
25 //
26 // However, nVidia and AMD GPUs don't have proper vector registers.  Instead, a
27 // "vector load" loads directly into a series of scalar registers.  In effect,
28 // extracting the elements of the vector is free.  It's therefore always
29 // beneficial to vectorize a sequence of loads on these architectures.
30 //
31 // Vectorizing (perhaps a better name might be "coalescing") loads can have
32 // large performance impacts on GPU kernels, and opportunities for vectorizing
33 // are common in GPU code.  This pass tries very hard to find such
34 // opportunities; its runtime is quadratic in the number of loads in a BB.
35 //
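// As an illustrative sketch (not taken from an actual test case), four
// consecutive i32 loads such as
//
//   %l0 = load i32, i32* %p0
//   %l1 = load i32, i32* %p1   ; %p0 + 4 bytes
//   %l2 = load i32, i32* %p2   ; %p0 + 8 bytes
//   %l3 = load i32, i32* %p3   ; %p0 + 12 bytes
//
// become a single <4 x i32> load followed by extractelement instructions that
// recover the original scalar values.
//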
36 // Some CPU architectures, such as ARM, have instructions that load into
37 // multiple scalar registers, similar to a GPU vectorized load.  In theory ARM
38 // could use this pass (with some modifications), but currently it implements
39 // its own pass to do something similar to what we do here.
40 
41 #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
42 #include "llvm/ADT/APInt.h"
43 #include "llvm/ADT/ArrayRef.h"
44 #include "llvm/ADT/MapVector.h"
45 #include "llvm/ADT/PostOrderIterator.h"
46 #include "llvm/ADT/STLExtras.h"
47 #include "llvm/ADT/SmallPtrSet.h"
48 #include "llvm/ADT/SmallVector.h"
49 #include "llvm/ADT/Statistic.h"
50 #include "llvm/ADT/iterator_range.h"
51 #include "llvm/Analysis/AliasAnalysis.h"
52 #include "llvm/Analysis/AssumptionCache.h"
53 #include "llvm/Analysis/MemoryLocation.h"
54 #include "llvm/Analysis/ScalarEvolution.h"
55 #include "llvm/Analysis/TargetTransformInfo.h"
56 #include "llvm/Analysis/ValueTracking.h"
57 #include "llvm/Analysis/VectorUtils.h"
58 #include "llvm/IR/Attributes.h"
59 #include "llvm/IR/BasicBlock.h"
60 #include "llvm/IR/Constants.h"
61 #include "llvm/IR/DataLayout.h"
62 #include "llvm/IR/DerivedTypes.h"
63 #include "llvm/IR/Dominators.h"
64 #include "llvm/IR/Function.h"
65 #include "llvm/IR/GetElementPtrTypeIterator.h"
66 #include "llvm/IR/IRBuilder.h"
67 #include "llvm/IR/InstrTypes.h"
68 #include "llvm/IR/Instruction.h"
69 #include "llvm/IR/Instructions.h"
70 #include "llvm/IR/Module.h"
71 #include "llvm/IR/Type.h"
72 #include "llvm/IR/Value.h"
73 #include "llvm/InitializePasses.h"
74 #include "llvm/Pass.h"
75 #include "llvm/Support/Casting.h"
76 #include "llvm/Support/Debug.h"
77 #include "llvm/Support/KnownBits.h"
78 #include "llvm/Support/MathExtras.h"
79 #include "llvm/Support/raw_ostream.h"
80 #include "llvm/Transforms/Utils/Local.h"
81 #include "llvm/Transforms/Vectorize.h"
82 #include <algorithm>
83 #include <cassert>
84 #include <cstdlib>
85 #include <tuple>
86 #include <utility>
87 
88 using namespace llvm;
89 
90 #define DEBUG_TYPE "load-store-vectorizer"
91 
92 STATISTIC(NumVectorInstructions, "Number of vector accesses generated");
93 STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized");
94 
95 // FIXME: Assuming stack alignment of 4 is always good enough
96 static const unsigned StackAdjustedAlignment = 4;
97 
98 namespace {
99 
100 /// ChainID is an arbitrary token that is allowed to be different only for the
101 /// accesses that are guaranteed to be considered non-consecutive by
102 /// Vectorizer::isConsecutiveAccess. It's used for grouping instructions
103 /// together and reducing the number of instructions the main search operates on
104 /// at a time, i.e. this is to reduce compile time and nothing else as the main
105 /// search has O(n^2) time complexity. The underlying type of ChainID should not
106 /// be relied upon.
107 using ChainID = const Value *;
108 using InstrList = SmallVector<Instruction *, 8>;
109 using InstrListMap = MapVector<ChainID, InstrList>;
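// For example, loads based on two different allocas get different ChainIDs and
// are never compared against each other, while accesses whose pointers trace
// back to the same underlying object share a ChainID and stay in one list.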
110 
111 class Vectorizer {
112   Function &F;
113   AliasAnalysis &AA;
114   AssumptionCache &AC;
115   DominatorTree &DT;
116   ScalarEvolution &SE;
117   TargetTransformInfo &TTI;
118   const DataLayout &DL;
119   IRBuilder<> Builder;
120 
121 public:
122   Vectorizer(Function &F, AliasAnalysis &AA, AssumptionCache &AC,
123              DominatorTree &DT, ScalarEvolution &SE, TargetTransformInfo &TTI)
124       : F(F), AA(AA), AC(AC), DT(DT), SE(SE), TTI(TTI),
125         DL(F.getParent()->getDataLayout()), Builder(SE.getContext()) {}
126 
127   bool run();
128 
129 private:
130   unsigned getPointerAddressSpace(Value *I);
131 
132   static const unsigned MaxDepth = 3;
133 
134   bool isConsecutiveAccess(Value *A, Value *B);
135   bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt PtrDelta,
136                               unsigned Depth = 0) const;
137   bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta,
138                                    unsigned Depth) const;
139   bool lookThroughSelects(Value *PtrA, Value *PtrB, const APInt &PtrDelta,
140                           unsigned Depth) const;
141 
142   /// After vectorization, reorder the instructions that I depends on
143   /// (the instructions defining its operands), to ensure they dominate I.
144   void reorder(Instruction *I);
145 
146   /// Returns the first and the last instructions in Chain.
147   std::pair<BasicBlock::iterator, BasicBlock::iterator>
148   getBoundaryInstrs(ArrayRef<Instruction *> Chain);
149 
150   /// Erases the original instructions after vectorizing.
151   void eraseInstructions(ArrayRef<Instruction *> Chain);
152 
153   /// "Legalize" the vector type that would be produced by combining \p
154   /// ElementSizeBits elements in \p Chain. Break into two pieces such that the
155   /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is
156   /// expected to have more than 4 elements.
157   std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>>
158   splitOddVectorElts(ArrayRef<Instruction *> Chain, unsigned ElementSizeBits);
159 
160   /// Finds the largest prefix of Chain that's vectorizable, checking for
161   /// intervening instructions which may affect the memory accessed by the
162   /// instructions within Chain.
163   ///
164   /// The elements of \p Chain must be all loads or all stores and must be in
165   /// address order.
166   ArrayRef<Instruction *> getVectorizablePrefix(ArrayRef<Instruction *> Chain);
167 
168   /// Collects load and store instructions to vectorize.
169   std::pair<InstrListMap, InstrListMap> collectInstructions(BasicBlock *BB);
170 
171   /// Processes the collected instructions, the \p Map. The values of \p Map
172   /// should be all loads or all stores.
173   bool vectorizeChains(InstrListMap &Map);
174 
175   /// Finds the load/stores to consecutive memory addresses and vectorizes them.
176   bool vectorizeInstructions(ArrayRef<Instruction *> Instrs);
177 
178   /// Vectorizes the load instructions in Chain.
179   bool
180   vectorizeLoadChain(ArrayRef<Instruction *> Chain,
181                      SmallPtrSet<Instruction *, 16> *InstructionsProcessed);
182 
183   /// Vectorizes the store instructions in Chain.
184   bool
185   vectorizeStoreChain(ArrayRef<Instruction *> Chain,
186                       SmallPtrSet<Instruction *, 16> *InstructionsProcessed);
187 
188   /// Checks whether this load/store access is misaligned.
189   /// If the access is allowed, sets \p RelativeSpeed to a value suitable for
190   /// comparison against another result for the same \p AddressSpace and
191   /// potentially different \p Alignment and \p SzInBytes.
192   bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
193                           Align Alignment, unsigned &RelativeSpeed);
194 };
195 
196 class LoadStoreVectorizerLegacyPass : public FunctionPass {
197 public:
198   static char ID;
199 
200   LoadStoreVectorizerLegacyPass() : FunctionPass(ID) {
201     initializeLoadStoreVectorizerLegacyPassPass(*PassRegistry::getPassRegistry());
202   }
203 
204   bool runOnFunction(Function &F) override;
205 
206   StringRef getPassName() const override {
207     return "GPU Load and Store Vectorizer";
208   }
209 
210   void getAnalysisUsage(AnalysisUsage &AU) const override {
211     AU.addRequired<AAResultsWrapperPass>();
212     AU.addRequired<AssumptionCacheTracker>();
213     AU.addRequired<ScalarEvolutionWrapperPass>();
214     AU.addRequired<DominatorTreeWrapperPass>();
215     AU.addRequired<TargetTransformInfoWrapperPass>();
216     AU.setPreservesCFG();
217   }
218 };
219 
220 } // end anonymous namespace
221 
222 char LoadStoreVectorizerLegacyPass::ID = 0;
223 
224 INITIALIZE_PASS_BEGIN(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
225                       "Vectorize load and Store instructions", false, false)
226 INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
227 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker);
228 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
229 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
230 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
231 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
232 INITIALIZE_PASS_END(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
233                     "Vectorize load and store instructions", false, false)
234 
235 Pass *llvm::createLoadStoreVectorizerPass() {
236   return new LoadStoreVectorizerLegacyPass();
237 }
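// A minimal usage sketch (illustrative, not part of this file): with the new
// pass manager the pass can be added to a function pipeline,
//
//   FunctionPassManager FPM;
//   FPM.addPass(LoadStoreVectorizerPass());
//
// or requested from opt via -passes=load-store-vectorizer.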
238 
239 bool LoadStoreVectorizerLegacyPass::runOnFunction(Function &F) {
240   // Don't vectorize when the attribute NoImplicitFloat is used.
241   if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat))
242     return false;
243 
244   AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
245   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
246   ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
247   TargetTransformInfo &TTI =
248       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
249 
250   AssumptionCache &AC =
251       getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
252 
253   Vectorizer V(F, AA, AC, DT, SE, TTI);
254   return V.run();
255 }
256 
257 PreservedAnalyses LoadStoreVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) {
258   // Don't vectorize when the attribute NoImplicitFloat is used.
259   if (F.hasFnAttribute(Attribute::NoImplicitFloat))
260     return PreservedAnalyses::all();
261 
262   AliasAnalysis &AA = AM.getResult<AAManager>(F);
263   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
264   ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
265   TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
266   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
267 
268   Vectorizer V(F, AA, AC, DT, SE, TTI);
269   bool Changed = V.run();
270   PreservedAnalyses PA;
271   PA.preserveSet<CFGAnalyses>();
272   return Changed ? PA : PreservedAnalyses::all();
273 }
274 
275 // The real propagateMetadata expects a SmallVector<Value*>, but we deal in
276 // vectors of Instructions.
277 static void propagateMetadata(Instruction *I, ArrayRef<Instruction *> IL) {
278   SmallVector<Value *, 8> VL(IL.begin(), IL.end());
279   propagateMetadata(I, VL);
280 }
281 
282 // Vectorizer Implementation
283 bool Vectorizer::run() {
284   bool Changed = false;
285 
286   // Scan the blocks in the function in post order.
287   for (BasicBlock *BB : post_order(&F)) {
288     InstrListMap LoadRefs, StoreRefs;
289     std::tie(LoadRefs, StoreRefs) = collectInstructions(BB);
290     Changed |= vectorizeChains(LoadRefs);
291     Changed |= vectorizeChains(StoreRefs);
292   }
293 
294   return Changed;
295 }
296 
297 unsigned Vectorizer::getPointerAddressSpace(Value *I) {
298   if (LoadInst *L = dyn_cast<LoadInst>(I))
299     return L->getPointerAddressSpace();
300   if (StoreInst *S = dyn_cast<StoreInst>(I))
301     return S->getPointerAddressSpace();
302   return -1;
303 }
304 
305 // FIXME: Merge with llvm::isConsecutiveAccess
306 bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
307   Value *PtrA = getLoadStorePointerOperand(A);
308   Value *PtrB = getLoadStorePointerOperand(B);
309   unsigned ASA = getPointerAddressSpace(A);
310   unsigned ASB = getPointerAddressSpace(B);
311 
312   // Check that the address spaces match and that the pointers are valid.
313   if (!PtrA || !PtrB || (ASA != ASB))
314     return false;
315 
316   // Make sure that A and B are different pointers whose types have the same size.
317   Type *PtrATy = getLoadStoreType(A);
318   Type *PtrBTy = getLoadStoreType(B);
319   if (PtrA == PtrB ||
320       PtrATy->isVectorTy() != PtrBTy->isVectorTy() ||
321       DL.getTypeStoreSize(PtrATy) != DL.getTypeStoreSize(PtrBTy) ||
322       DL.getTypeStoreSize(PtrATy->getScalarType()) !=
323           DL.getTypeStoreSize(PtrBTy->getScalarType()))
324     return false;
325 
326   unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA);
327   APInt Size(PtrBitWidth, DL.getTypeStoreSize(PtrATy));
328 
329   return areConsecutivePointers(PtrA, PtrB, Size);
330 }
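// As an illustration of isConsecutiveAccess above (an assumed example, not
// from a test): for two simple i32 accesses in the same address space, Size is
// 4 bytes, so B is treated as consecutive after A exactly when B's pointer is
// provably 4 bytes past A's pointer.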
331 
332 bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB,
333                                         APInt PtrDelta, unsigned Depth) const {
334   unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType());
335   APInt OffsetA(PtrBitWidth, 0);
336   APInt OffsetB(PtrBitWidth, 0);
337   PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
338   PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
339 
340   unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType());
341 
342   if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType()))
343     return false;
344 
345   // In case we have to shrink the pointer,
346   // stripAndAccumulateInBoundsConstantOffsets should properly handle a
347   // possible overflow and the value should fit into the smallest data type
348   // used in the cast/gep chain.
349   assert(OffsetA.getMinSignedBits() <= NewPtrBitWidth &&
350          OffsetB.getMinSignedBits() <= NewPtrBitWidth);
351 
352   OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
353   OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
354   PtrDelta = PtrDelta.sextOrTrunc(NewPtrBitWidth);
355 
356   APInt OffsetDelta = OffsetB - OffsetA;
357 
358   // Check if they are based on the same pointer. That makes the offsets
359   // sufficient.
360   if (PtrA == PtrB)
361     return OffsetDelta == PtrDelta;
362 
363   // Compute the base pointer delta needed for the final delta to equal the
364   // pointer delta requested.
365   APInt BaseDelta = PtrDelta - OffsetDelta;
366 
367   // Compute the distance with SCEV between the base pointers.
368   const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
369   const SCEV *PtrSCEVB = SE.getSCEV(PtrB);
370   const SCEV *C = SE.getConstant(BaseDelta);
371   const SCEV *X = SE.getAddExpr(PtrSCEVA, C);
372   if (X == PtrSCEVB)
373     return true;
374 
375   // The above check will not catch the cases where one of the pointers is
376   // factorized but the other one is not, such as (C + (S * (A + B))) vs
377   // (AS + BS). Get the minus SCEV. That will allow re-combining the
378   // expressions and getting the simplified difference.
379   const SCEV *Dist = SE.getMinusSCEV(PtrSCEVB, PtrSCEVA);
380   if (C == Dist)
381     return true;
382 
383   // Sometimes even this doesn't work, because SCEV can't always see through
384   // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
385   // things the hard way.
386   return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
387 }
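// A worked example for the offset arithmetic above (illustrative values): if
// PtrA is `gep inbounds i8, i8* %base, i64 4`, PtrB is
// `gep inbounds i8, i8* %base, i64 8`, and PtrDelta is 4, then stripping the
// constant offsets leaves both pointers equal to %base with OffsetA = 4 and
// OffsetB = 8, so OffsetDelta == PtrDelta and the pointers are consecutive.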
388 
389 static bool checkNoWrapFlags(Instruction *I, bool Signed) {
390   BinaryOperator *BinOpI = cast<BinaryOperator>(I);
391   return (Signed && BinOpI->hasNoSignedWrap()) ||
392          (!Signed && BinOpI->hasNoUnsignedWrap());
393 }
394 
395 static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
396                                    unsigned MatchingOpIdxA, Instruction *AddOpB,
397                                    unsigned MatchingOpIdxB, bool Signed) {
398   // If both OpA and OpB are adds with NSW/NUW and with
399   // one of the operands being the same, we can guarantee that the
400   // transformation is safe if we can prove that OpA won't overflow when
401   // IdxDiff is added to the other operand of OpA.
402   // For example:
403   //  %tmp7 = add nsw i32 %tmp2, %v0
404   //  %tmp8 = sext i32 %tmp7 to i64
405   //  ...
406   //  %tmp11 = add nsw i32 %v0, 1
407   //  %tmp12 = add nsw i32 %tmp2, %tmp11
408   //  %tmp13 = sext i32 %tmp12 to i64
409   //
410   //  Both %tmp7 and %tmp12 have the nsw flag and the first operand
411   //  of both is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
412   //  because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 have the
413   //  nsw flag.
414   assert(AddOpA->getOpcode() == Instruction::Add &&
415          AddOpB->getOpcode() == Instruction::Add &&
416          checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed));
417   if (AddOpA->getOperand(MatchingOpIdxA) ==
418       AddOpB->getOperand(MatchingOpIdxB)) {
419     Value *OtherOperandA = AddOpA->getOperand(MatchingOpIdxA == 1 ? 0 : 1);
420     Value *OtherOperandB = AddOpB->getOperand(MatchingOpIdxB == 1 ? 0 : 1);
421     Instruction *OtherInstrA = dyn_cast<Instruction>(OtherOperandA);
422     Instruction *OtherInstrB = dyn_cast<Instruction>(OtherOperandB);
423     // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
424     if (OtherInstrB && OtherInstrB->getOpcode() == Instruction::Add &&
425         checkNoWrapFlags(OtherInstrB, Signed) &&
426         isa<ConstantInt>(OtherInstrB->getOperand(1))) {
427       int64_t CstVal =
428           cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
429       if (OtherInstrB->getOperand(0) == OtherOperandA &&
430           IdxDiff.getSExtValue() == CstVal)
431         return true;
432     }
433     // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
434     if (OtherInstrA && OtherInstrA->getOpcode() == Instruction::Add &&
435         checkNoWrapFlags(OtherInstrA, Signed) &&
436         isa<ConstantInt>(OtherInstrA->getOperand(1))) {
437       int64_t CstVal =
438           cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
439       if (OtherInstrA->getOperand(0) == OtherOperandB &&
440           IdxDiff.getSExtValue() == -CstVal)
441         return true;
442     }
443     // Match `x +nsw/nuw (y +nsw/nuw c)` and
444     // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
445     if (OtherInstrA && OtherInstrB &&
446         OtherInstrA->getOpcode() == Instruction::Add &&
447         OtherInstrB->getOpcode() == Instruction::Add &&
448         checkNoWrapFlags(OtherInstrA, Signed) &&
449         checkNoWrapFlags(OtherInstrB, Signed) &&
450         isa<ConstantInt>(OtherInstrA->getOperand(1)) &&
451         isa<ConstantInt>(OtherInstrB->getOperand(1))) {
452       int64_t CstValA =
453           cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
454       int64_t CstValB =
455           cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
456       if (OtherInstrA->getOperand(0) == OtherInstrB->getOperand(0) &&
457           IdxDiff.getSExtValue() == (CstValB - CstValA))
458         return true;
459     }
460   }
461   return false;
462 }
463 
464 bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
465                                              APInt PtrDelta,
466                                              unsigned Depth) const {
467   auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
468   auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
469   if (!GEPA || !GEPB)
470     return lookThroughSelects(PtrA, PtrB, PtrDelta, Depth);
471 
472   // Look through GEPs after checking they're the same except for the last
473   // index.
474   if (GEPA->getNumOperands() != GEPB->getNumOperands() ||
475       GEPA->getPointerOperand() != GEPB->getPointerOperand())
476     return false;
477   gep_type_iterator GTIA = gep_type_begin(GEPA);
478   gep_type_iterator GTIB = gep_type_begin(GEPB);
479   for (unsigned I = 0, E = GEPA->getNumIndices() - 1; I < E; ++I) {
480     if (GTIA.getOperand() != GTIB.getOperand())
481       return false;
482     ++GTIA;
483     ++GTIB;
484   }
485 
486   Instruction *OpA = dyn_cast<Instruction>(GTIA.getOperand());
487   Instruction *OpB = dyn_cast<Instruction>(GTIB.getOperand());
488   if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
489       OpA->getType() != OpB->getType())
490     return false;
491 
492   if (PtrDelta.isNegative()) {
493     if (PtrDelta.isMinSignedValue())
494       return false;
495     PtrDelta.negate();
496     std::swap(OpA, OpB);
497   }
498   uint64_t Stride = DL.getTypeAllocSize(GTIA.getIndexedType());
499   if (PtrDelta.urem(Stride) != 0)
500     return false;
501   unsigned IdxBitWidth = OpA->getType()->getScalarSizeInBits();
502   APInt IdxDiff = PtrDelta.udiv(Stride).zext(IdxBitWidth);
503 
504   // Only look through a ZExt/SExt.
505   if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
506     return false;
507 
508   bool Signed = isa<SExtInst>(OpA);
509 
510   // At this point A could be a function parameter, i.e. not an instruction
511   Value *ValA = OpA->getOperand(0);
512   OpB = dyn_cast<Instruction>(OpB->getOperand(0));
513   if (!OpB || ValA->getType() != OpB->getType())
514     return false;
515 
516   // Now we need to prove that adding IdxDiff to ValA won't overflow.
517   bool Safe = false;
518 
519   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
520   // ValA, we're okay.
521   if (OpB->getOpcode() == Instruction::Add &&
522       isa<ConstantInt>(OpB->getOperand(1)) &&
523       IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
524       checkNoWrapFlags(OpB, Signed))
525     Safe = true;
526 
527   // Second attempt: check if we have eligible add NSW/NUW instruction
528   // sequences.
529   OpA = dyn_cast<Instruction>(ValA);
530   if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
531       OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) &&
532       checkNoWrapFlags(OpB, Signed)) {
533     // In the checks below a matching operand in OpA and OpB is
534     // an operand which is the same in those two instructions.
535     // Below we account for possible orders of the operands of
536     // these add instructions.
537     for (unsigned MatchingOpIdxA : {0, 1})
538       for (unsigned MatchingOpIdxB : {0, 1})
539         if (!Safe)
540           Safe = checkIfSafeAddSequence(IdxDiff, OpA, MatchingOpIdxA, OpB,
541                                         MatchingOpIdxB, Signed);
542   }
543 
544   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
545 
546   // Third attempt:
547   // If all set bits of IdxDiff, and any higher-order bits other than the sign
548   // bit, are known to be zero in ValA, we can add IdxDiff to it while
549   // guaranteeing no overflow of any sort.
550   if (!Safe) {
551     KnownBits Known(BitWidth);
552     computeKnownBits(ValA, Known, DL, 0, &AC, OpB, &DT);
553     APInt BitsAllowedToBeSet = Known.Zero.zext(IdxDiff.getBitWidth());
554     if (Signed)
555       BitsAllowedToBeSet.clearBit(BitWidth - 1);
556     if (BitsAllowedToBeSet.ult(IdxDiff))
557       return false;
558   }
559 
560   const SCEV *OffsetSCEVA = SE.getSCEV(ValA);
561   const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
562   const SCEV *C = SE.getConstant(IdxDiff.trunc(BitWidth));
563   const SCEV *X = SE.getAddExpr(OffsetSCEVA, C);
564   return X == OffsetSCEVB;
565 }
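// A sketch of the kind of pattern lookThroughComplexAddresses handles (names
// are illustrative):
//
//   %i1   = add nsw i32 %i, 1
//   %s0   = sext i32 %i to i64
//   %s1   = sext i32 %i1 to i64
//   %gepA = getelementptr inbounds i32, i32* %p, i64 %s0
//   %gepB = getelementptr inbounds i32, i32* %p, i64 %s1
//
// Here Stride is 4 and IdxDiff is 1; the nsw flag on %i1 shows that adding 1
// to %i cannot overflow, so the two addresses are exactly 4 bytes apart.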
566 
567 bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB,
568                                     const APInt &PtrDelta,
569                                     unsigned Depth) const {
570   if (Depth++ == MaxDepth)
571     return false;
572 
573   if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
574     if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
575       return SelectA->getCondition() == SelectB->getCondition() &&
576              areConsecutivePointers(SelectA->getTrueValue(),
577                                     SelectB->getTrueValue(), PtrDelta, Depth) &&
578              areConsecutivePointers(SelectA->getFalseValue(),
579                                     SelectB->getFalseValue(), PtrDelta, Depth);
580     }
581   }
582   return false;
583 }
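// For example (illustrative names), two selects over the same condition,
//
//   %pA = select i1 %c, i32* %x0, i32* %y0
//   %pB = select i1 %c, i32* %x1, i32* %y1
//
// are treated as PtrDelta apart when %x1 is PtrDelta past %x0 and %y1 is
// PtrDelta past %y0.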
584 
585 void Vectorizer::reorder(Instruction *I) {
586   SmallPtrSet<Instruction *, 16> InstructionsToMove;
587   SmallVector<Instruction *, 16> Worklist;
588 
589   Worklist.push_back(I);
590   while (!Worklist.empty()) {
591     Instruction *IW = Worklist.pop_back_val();
592     int NumOperands = IW->getNumOperands();
593     for (int i = 0; i < NumOperands; i++) {
594       Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i));
595       if (!IM || IM->getOpcode() == Instruction::PHI)
596         continue;
597 
598       // If IM is in another BB, no need to move it, because this pass only
599       // vectorizes instructions within one BB.
600       if (IM->getParent() != I->getParent())
601         continue;
602 
603       if (!IM->comesBefore(I)) {
604         InstructionsToMove.insert(IM);
605         Worklist.push_back(IM);
606       }
607     }
608   }
609 
610   // All instructions to move should follow I. Start from I, not from begin().
611   for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;
612        ++BBI) {
613     if (!InstructionsToMove.count(&*BBI))
614       continue;
615     Instruction *IM = &*BBI;
616     --BBI;
617     IM->removeFromParent();
618     IM->insertBefore(I);
619   }
620 }
621 
622 std::pair<BasicBlock::iterator, BasicBlock::iterator>
623 Vectorizer::getBoundaryInstrs(ArrayRef<Instruction *> Chain) {
624   Instruction *C0 = Chain[0];
625   BasicBlock::iterator FirstInstr = C0->getIterator();
626   BasicBlock::iterator LastInstr = C0->getIterator();
627 
628   BasicBlock *BB = C0->getParent();
629   unsigned NumFound = 0;
630   for (Instruction &I : *BB) {
631     if (!is_contained(Chain, &I))
632       continue;
633 
634     ++NumFound;
635     if (NumFound == 1) {
636       FirstInstr = I.getIterator();
637     }
638     if (NumFound == Chain.size()) {
639       LastInstr = I.getIterator();
640       break;
641     }
642   }
643 
644   // Range is [first, last).
645   return std::make_pair(FirstInstr, ++LastInstr);
646 }
647 
648 void Vectorizer::eraseInstructions(ArrayRef<Instruction *> Chain) {
649   SmallVector<Instruction *, 16> Instrs;
650   for (Instruction *I : Chain) {
651     Value *PtrOperand = getLoadStorePointerOperand(I);
652     assert(PtrOperand && "Instruction must have a pointer operand.");
653     Instrs.push_back(I);
654     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand))
655       Instrs.push_back(GEP);
656   }
657 
658   // Erase instructions.
659   for (Instruction *I : Instrs)
660     if (I->use_empty())
661       I->eraseFromParent();
662 }
663 
664 std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>>
665 Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain,
666                                unsigned ElementSizeBits) {
667   unsigned ElementSizeBytes = ElementSizeBits / 8;
668   unsigned SizeBytes = ElementSizeBytes * Chain.size();
669   unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes;
670   if (NumLeft == Chain.size()) {
671     if ((NumLeft & 1) == 0)
672       NumLeft /= 2; // Split even in half
673     else
674       --NumLeft;    // Split off last element
675   } else if (NumLeft == 0)
676     NumLeft = 1;
677   return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft));
678 }
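// A worked example of the split above (assumed sizes): with ElementSizeBits of
// 32 and a chain of 7 accesses, SizeBytes is 28 and NumLeft is 7, which equals
// Chain.size() and is odd, so the split is 6 + 1 elements.  With
// ElementSizeBits of 16 and a chain of 3, SizeBytes is 6 and NumLeft is 2, so
// the split is 2 + 1.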
679 
680 ArrayRef<Instruction *>
681 Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
682   // These are in BB order, unlike Chain, which is in address order.
683   SmallVector<Instruction *, 16> MemoryInstrs;
684   SmallVector<Instruction *, 16> ChainInstrs;
685 
686   bool IsLoadChain = isa<LoadInst>(Chain[0]);
687   LLVM_DEBUG({
688     for (Instruction *I : Chain) {
689       if (IsLoadChain)
690         assert(isa<LoadInst>(I) &&
691                "All elements of Chain must be loads, or all must be stores.");
692       else
693         assert(isa<StoreInst>(I) &&
694                "All elements of Chain must be loads, or all must be stores.");
695     }
696   });
697 
698   for (Instruction &I : make_range(getBoundaryInstrs(Chain))) {
699     if ((isa<LoadInst>(I) || isa<StoreInst>(I)) && is_contained(Chain, &I)) {
700       ChainInstrs.push_back(&I);
701       continue;
702     }
703     if (!isGuaranteedToTransferExecutionToSuccessor(&I)) {
704       LLVM_DEBUG(dbgs() << "LSV: Found instruction may not transfer execution: "
705                         << I << '\n');
706       break;
707     }
708     if (I.mayReadOrWriteMemory())
709       MemoryInstrs.push_back(&I);
710   }
711 
712   // Loop until we find an instruction in ChainInstrs that we can't vectorize.
713   unsigned ChainInstrIdx = 0;
714   Instruction *BarrierMemoryInstr = nullptr;
715 
716   for (unsigned E = ChainInstrs.size(); ChainInstrIdx < E; ++ChainInstrIdx) {
717     Instruction *ChainInstr = ChainInstrs[ChainInstrIdx];
718 
719     // If a barrier memory instruction was found, chain instructions that follow
720     // will not be added to the valid prefix.
721     if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(ChainInstr))
722       break;
723 
724     // Check (in BB order) if any instruction prevents ChainInstr from being
725     // vectorized. Find and store the first such "conflicting" instruction.
726     for (Instruction *MemInstr : MemoryInstrs) {
727       // If a barrier memory instruction was found, do not check past it.
728       if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(MemInstr))
729         break;
730 
731       auto *MemLoad = dyn_cast<LoadInst>(MemInstr);
732       auto *ChainLoad = dyn_cast<LoadInst>(ChainInstr);
733       if (MemLoad && ChainLoad)
734         continue;
735 
736       // We can ignore the alias if we have a load/store pair and the load
737       // is known to be invariant. The load cannot be clobbered by the store.
738       auto IsInvariantLoad = [](const LoadInst *LI) -> bool {
739         return LI->hasMetadata(LLVMContext::MD_invariant_load);
740       };
741 
742       if (IsLoadChain) {
743         // We can ignore the alias as long as the load comes before the store,
744         // because that means we won't be moving the load past the store to
745         // vectorize it (the vectorized load is inserted at the location of the
746         // first load in the chain).
747         if (ChainInstr->comesBefore(MemInstr) ||
748             (ChainLoad && IsInvariantLoad(ChainLoad)))
749           continue;
750       } else {
751         // Same case, but in reverse.
752         if (MemInstr->comesBefore(ChainInstr) ||
753             (MemLoad && IsInvariantLoad(MemLoad)))
754           continue;
755       }
756 
757       ModRefInfo MR =
758           AA.getModRefInfo(MemInstr, MemoryLocation::get(ChainInstr));
759       if (IsLoadChain ? isModSet(MR) : isModOrRefSet(MR)) {
760         LLVM_DEBUG({
761           dbgs() << "LSV: Found alias:\n"
762                     "  Aliasing instruction:\n"
763                  << "  " << *MemInstr << '\n'
764                  << "  Aliased instruction and pointer:\n"
765                  << "  " << *ChainInstr << '\n'
766                  << "  " << *getLoadStorePointerOperand(ChainInstr) << '\n';
767         });
768         // Save this aliasing memory instruction as a barrier, but allow other
769         // instructions that precede the barrier to be vectorized with this one.
770         BarrierMemoryInstr = MemInstr;
771         break;
772       }
773     }
774     // Continue the search only for store chains, since vectorizing stores that
775     // precede an aliasing load is valid. Conversely, vectorizing loads is valid
776     // up to an aliasing store, but should not pull loads from further down in
777     // the basic block.
778     if (IsLoadChain && BarrierMemoryInstr) {
779       // The BarrierMemoryInstr is a store that precedes ChainInstr.
780       assert(BarrierMemoryInstr->comesBefore(ChainInstr));
781       break;
782     }
783   }
784 
785   // Find the largest prefix of Chain whose elements are all in
786   // ChainInstrs[0, ChainInstrIdx).  This is the largest vectorizable prefix of
787   // Chain.  (Recall that Chain is in address order, but ChainInstrs is in BB
788   // order.)
789   SmallPtrSet<Instruction *, 8> VectorizableChainInstrs(
790       ChainInstrs.begin(), ChainInstrs.begin() + ChainInstrIdx);
791   unsigned ChainIdx = 0;
792   for (unsigned ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) {
793     if (!VectorizableChainInstrs.count(Chain[ChainIdx]))
794       break;
795   }
796   return Chain.slice(0, ChainIdx);
797 }
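// Illustrative behavior of getVectorizablePrefix: given chain loads L1, L2, L3
// in both address and BB order, with a store S between L2 and L3 that may
// clobber L3's location, the returned prefix is [L1, L2]; vectorizing L3 would
// require hoisting it above S, so it is dropped.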
798 
799 static ChainID getChainID(const Value *Ptr) {
800   const Value *ObjPtr = getUnderlyingObject(Ptr);
801   if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
802     // The selects themselves are distinct instructions even if they share the
803     // same condition and evaluate to consecutive pointers for true and false
804     // values of the condition. Therefore using the selects themselves for
805     // grouping instructions would put consecutive accesses into different lists
806     // and they won't even be checked for being consecutive, and won't be
807     // vectorized.
808     return Sel->getCondition();
809   }
810   return ObjPtr;
811 }
812 
813 std::pair<InstrListMap, InstrListMap>
814 Vectorizer::collectInstructions(BasicBlock *BB) {
815   InstrListMap LoadRefs;
816   InstrListMap StoreRefs;
817 
818   for (Instruction &I : *BB) {
819     if (!I.mayReadOrWriteMemory())
820       continue;
821 
822     if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
823       if (!LI->isSimple())
824         continue;
825 
826       // Skip if it's not legal.
827       if (!TTI.isLegalToVectorizeLoad(LI))
828         continue;
829 
830       Type *Ty = LI->getType();
831       if (!VectorType::isValidElementType(Ty->getScalarType()))
832         continue;
833 
834       // Skip weird non-byte sizes. They probably aren't worth the effort of
835       // handling correctly.
836       unsigned TySize = DL.getTypeSizeInBits(Ty);
837       if ((TySize % 8) != 0)
838         continue;
839 
840       // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
841       // functions currently use an integer type for the vectorized
842       // load/store, and do not support casting between the integer type and a
843       // vector of pointers (e.g. i64 to <2 x i16*>).
844       if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
845         continue;
846 
847       Value *Ptr = LI->getPointerOperand();
848       unsigned AS = Ptr->getType()->getPointerAddressSpace();
849       unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
850 
851       unsigned VF = VecRegSize / TySize;
852       VectorType *VecTy = dyn_cast<VectorType>(Ty);
853 
854       // No point in looking at these if they're too big to vectorize.
855       if (TySize > VecRegSize / 2 ||
856           (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
857         continue;
858 
859       // Save the load locations.
860       const ChainID ID = getChainID(Ptr);
861       LoadRefs[ID].push_back(LI);
862     } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
863       if (!SI->isSimple())
864         continue;
865 
866       // Skip if it's not legal.
867       if (!TTI.isLegalToVectorizeStore(SI))
868         continue;
869 
870       Type *Ty = SI->getValueOperand()->getType();
871       if (!VectorType::isValidElementType(Ty->getScalarType()))
872         continue;
873 
874       // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
875       // functions currently use an integer type for the vectorized
876       // load/store, and do not support casting between the integer type and a
877       // vector of pointers (e.g. i64 to <2 x i16*>).
878       if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
879         continue;
880 
881       // Skip weird non-byte sizes. They probably aren't worth the effort of
882       // handling correctly.
883       unsigned TySize = DL.getTypeSizeInBits(Ty);
884       if ((TySize % 8) != 0)
885         continue;
886 
887       Value *Ptr = SI->getPointerOperand();
888       unsigned AS = Ptr->getType()->getPointerAddressSpace();
889       unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
890 
891       unsigned VF = VecRegSize / TySize;
892       VectorType *VecTy = dyn_cast<VectorType>(Ty);
893 
894       // No point in looking at these if they're too big to vectorize.
895       if (TySize > VecRegSize / 2 ||
896           (VecTy && TTI.getStoreVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
897         continue;
898 
899       // Save store location.
900       const ChainID ID = getChainID(Ptr);
901       StoreRefs[ID].push_back(SI);
902     }
903   }
904 
905   return {LoadRefs, StoreRefs};
906 }
907 
908 bool Vectorizer::vectorizeChains(InstrListMap &Map) {
909   bool Changed = false;
910 
911   for (const std::pair<ChainID, InstrList> &Chain : Map) {
912     unsigned Size = Chain.second.size();
913     if (Size < 2)
914       continue;
915 
916     LLVM_DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n");
917 
918     // Process the instructions in chunks of 64.
919     for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) {
920       unsigned Len = std::min<unsigned>(CE - CI, 64);
921       ArrayRef<Instruction *> Chunk(&Chain.second[CI], Len);
922       Changed |= vectorizeInstructions(Chunk);
923     }
924   }
925 
926   return Changed;
927 }
928 
929 bool Vectorizer::vectorizeInstructions(ArrayRef<Instruction *> Instrs) {
930   LLVM_DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size()
931                     << " instructions.\n");
932   SmallVector<int, 16> Heads, Tails;
933   int ConsecutiveChain[64];
934 
935   // Do a quadratic search on all of the given loads/stores and find all of the
936   // pairs of loads/stores that follow each other.
937   for (int i = 0, e = Instrs.size(); i < e; ++i) {
938     ConsecutiveChain[i] = -1;
939     for (int j = e - 1; j >= 0; --j) {
940       if (i == j)
941         continue;
942 
943       if (isConsecutiveAccess(Instrs[i], Instrs[j])) {
944         if (ConsecutiveChain[i] != -1) {
945           int CurDistance = std::abs(ConsecutiveChain[i] - i);
946           int NewDistance = std::abs(ConsecutiveChain[i] - j);
947           if (j < i || NewDistance > CurDistance)
948             continue; // Should not insert.
949         }
950 
951         Tails.push_back(j);
952         Heads.push_back(i);
953         ConsecutiveChain[i] = j;
954       }
955     }
956   }
957 
958   bool Changed = false;
959   SmallPtrSet<Instruction *, 16> InstructionsProcessed;
960 
961   for (int Head : Heads) {
962     if (InstructionsProcessed.count(Instrs[Head]))
963       continue;
964     bool LongerChainExists = false;
965     for (unsigned TIt = 0; TIt < Tails.size(); TIt++)
966       if (Head == Tails[TIt] &&
967           !InstructionsProcessed.count(Instrs[Heads[TIt]])) {
968         LongerChainExists = true;
969         break;
970       }
971     if (LongerChainExists)
972       continue;
973 
974     // We found an instr that starts a chain. Now follow the chain and try to
975     // vectorize it.
976     SmallVector<Instruction *, 16> Operands;
977     int I = Head;
978     while (I != -1 && (is_contained(Tails, I) || is_contained(Heads, I))) {
979       if (InstructionsProcessed.count(Instrs[I]))
980         break;
981 
982       Operands.push_back(Instrs[I]);
983       I = ConsecutiveChain[I];
984     }
985 
986     bool Vectorized = false;
987     if (isa<LoadInst>(*Operands.begin()))
988       Vectorized = vectorizeLoadChain(Operands, &InstructionsProcessed);
989     else
990       Vectorized = vectorizeStoreChain(Operands, &InstructionsProcessed);
991 
992     Changed |= Vectorized;
993   }
994 
995   return Changed;
996 }
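// A small illustrative run of vectorizeInstructions: for three consecutive
// loads A, B, C (in address order), the quadratic search records
// ConsecutiveChain[A] = B and ConsecutiveChain[B] = C, with Heads = {A, B} and
// Tails = {B, C}.  B is not used as a chain head, because it is either the
// tail of the longer chain starting at A or already marked as processed, so
// the chain [A, B, C] is built and vectorized once, starting from A.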
997 
998 bool Vectorizer::vectorizeStoreChain(
999     ArrayRef<Instruction *> Chain,
1000     SmallPtrSet<Instruction *, 16> *InstructionsProcessed) {
1001   StoreInst *S0 = cast<StoreInst>(Chain[0]);
1002 
1003   // If the vector has an int element, default to int for the whole store.
1004   Type *StoreTy = nullptr;
1005   for (Instruction *I : Chain) {
1006     StoreTy = cast<StoreInst>(I)->getValueOperand()->getType();
1007     if (StoreTy->isIntOrIntVectorTy())
1008       break;
1009 
1010     if (StoreTy->isPtrOrPtrVectorTy()) {
1011       StoreTy = Type::getIntNTy(F.getParent()->getContext(),
1012                                 DL.getTypeSizeInBits(StoreTy));
1013       break;
1014     }
1015   }
1016   assert(StoreTy && "Failed to find store type");
1017 
1018   unsigned Sz = DL.getTypeSizeInBits(StoreTy);
1019   unsigned AS = S0->getPointerAddressSpace();
1020   unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
1021   unsigned VF = VecRegSize / Sz;
1022   unsigned ChainSize = Chain.size();
1023   Align Alignment = S0->getAlign();
1024 
1025   if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
1026     InstructionsProcessed->insert(Chain.begin(), Chain.end());
1027     return false;
1028   }
1029 
1030   ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain);
1031   if (NewChain.empty()) {
1032     // No vectorization possible.
1033     InstructionsProcessed->insert(Chain.begin(), Chain.end());
1034     return false;
1035   }
1036   if (NewChain.size() == 1) {
1037     // Failed after the first instruction. Discard it and try the smaller chain.
1038     InstructionsProcessed->insert(NewChain.front());
1039     return false;
1040   }
1041 
1042   // Update Chain to the valid vectorizable subchain.
1043   Chain = NewChain;
1044   ChainSize = Chain.size();
1045 
1046   // Check if it's legal to vectorize this chain. If not, split the chain and
1047   // try again.
1048   unsigned EltSzInBytes = Sz / 8;
1049   unsigned SzInBytes = EltSzInBytes * ChainSize;
1050 
1051   FixedVectorType *VecTy;
1052   auto *VecStoreTy = dyn_cast<FixedVectorType>(StoreTy);
1053   if (VecStoreTy)
1054     VecTy = FixedVectorType::get(StoreTy->getScalarType(),
1055                                  Chain.size() * VecStoreTy->getNumElements());
1056   else
1057     VecTy = FixedVectorType::get(StoreTy, Chain.size());
1058 
1059   // If it's more than the max vector size or the target has a better
1060   // vector factor, break it into two pieces.
1061   unsigned TargetVF = TTI.getStoreVectorFactor(VF, Sz, SzInBytes, VecTy);
1062   if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
1063     LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor."
1064                          " Creating two separate arrays.\n");
1065     bool Vectorized = false;
1066     Vectorized |=
1067         vectorizeStoreChain(Chain.slice(0, TargetVF), InstructionsProcessed);
1068     Vectorized |=
1069         vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed);
1070     return Vectorized;
1071   }
1072 
1073   LLVM_DEBUG({
1074     dbgs() << "LSV: Stores to vectorize:\n";
1075     for (Instruction *I : Chain)
1076       dbgs() << "  " << *I << "\n";
1077   });
1078 
1079   // We won't try again to vectorize the elements of the chain, regardless of
1080   // whether we succeed below.
1081   InstructionsProcessed->insert(Chain.begin(), Chain.end());
1082 
1083   // If the store is going to be misaligned, don't vectorize it.
1084   unsigned RelativeSpeed;
1085   if (accessIsMisaligned(SzInBytes, AS, Alignment, RelativeSpeed)) {
1086     if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
1087       unsigned SpeedBefore;
1088       accessIsMisaligned(EltSzInBytes, AS, Alignment, SpeedBefore);
1089       if (SpeedBefore > RelativeSpeed)
1090         return false;
1091 
1092       auto Chains = splitOddVectorElts(Chain, Sz);
1093       bool Vectorized = false;
1094       Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed);
1095       Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed);
1096       return Vectorized;
1097     }
1098 
1099     Align NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(),
1100                                                 Align(StackAdjustedAlignment),
1101                                                 DL, S0, nullptr, &DT);
1102     if (NewAlign >= Alignment)
1103       Alignment = NewAlign;
1104     else
1105       return false;
1106   }
1107 
1108   if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) {
1109     auto Chains = splitOddVectorElts(Chain, Sz);
1110     bool Vectorized = false;
1111     Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed);
1112     Vectorized |= vectorizeStoreChain(Chains.second, InstructionsProcessed);
1113     return Vectorized;
1114   }
1115 
1116   BasicBlock::iterator First, Last;
1117   std::tie(First, Last) = getBoundaryInstrs(Chain);
1118   Builder.SetInsertPoint(&*Last);
1119 
1120   Value *Vec = PoisonValue::get(VecTy);
1121 
1122   if (VecStoreTy) {
1123     unsigned VecWidth = VecStoreTy->getNumElements();
1124     for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
1125       StoreInst *Store = cast<StoreInst>(Chain[I]);
1126       for (unsigned J = 0, NE = VecStoreTy->getNumElements(); J != NE; ++J) {
1127         unsigned NewIdx = J + I * VecWidth;
1128         Value *Extract = Builder.CreateExtractElement(Store->getValueOperand(),
1129                                                       Builder.getInt32(J));
1130         if (Extract->getType() != StoreTy->getScalarType())
1131           Extract = Builder.CreateBitCast(Extract, StoreTy->getScalarType());
1132 
1133         Value *Insert =
1134             Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(NewIdx));
1135         Vec = Insert;
1136       }
1137     }
1138   } else {
1139     for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
1140       StoreInst *Store = cast<StoreInst>(Chain[I]);
1141       Value *Extract = Store->getValueOperand();
1142       if (Extract->getType() != StoreTy->getScalarType())
1143         Extract =
1144             Builder.CreateBitOrPointerCast(Extract, StoreTy->getScalarType());
1145 
1146       Value *Insert =
1147           Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(I));
1148       Vec = Insert;
1149     }
1150   }
1151 
1152   StoreInst *SI = Builder.CreateAlignedStore(
1153     Vec,
1154     Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS)),
1155     Alignment);
1156   propagateMetadata(SI, Chain);
1157 
1158   eraseInstructions(Chain);
1159   ++NumVectorInstructions;
1160   NumScalarsVectorized += Chain.size();
1161   return true;
1162 }
1163 
1164 bool Vectorizer::vectorizeLoadChain(
1165     ArrayRef<Instruction *> Chain,
1166     SmallPtrSet<Instruction *, 16> *InstructionsProcessed) {
1167   LoadInst *L0 = cast<LoadInst>(Chain[0]);
1168 
1169   // If the vector has an int element, default to int for the whole load.
1170   Type *LoadTy = nullptr;
1171   for (const auto &V : Chain) {
1172     LoadTy = cast<LoadInst>(V)->getType();
1173     if (LoadTy->isIntOrIntVectorTy())
1174       break;
1175 
1176     if (LoadTy->isPtrOrPtrVectorTy()) {
1177       LoadTy = Type::getIntNTy(F.getParent()->getContext(),
1178                                DL.getTypeSizeInBits(LoadTy));
1179       break;
1180     }
1181   }
1182   assert(LoadTy && "Can't determine LoadInst type from chain");
1183 
1184   unsigned Sz = DL.getTypeSizeInBits(LoadTy);
1185   unsigned AS = L0->getPointerAddressSpace();
1186   unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
1187   unsigned VF = VecRegSize / Sz;
1188   unsigned ChainSize = Chain.size();
1189   Align Alignment = L0->getAlign();
1190 
1191   if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
1192     InstructionsProcessed->insert(Chain.begin(), Chain.end());
1193     return false;
1194   }
1195 
1196   ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain);
1197   if (NewChain.empty()) {
1198     // No vectorization possible.
1199     InstructionsProcessed->insert(Chain.begin(), Chain.end());
1200     return false;
1201   }
1202   if (NewChain.size() == 1) {
1203     // Failed after the first instruction. Discard it and try the smaller chain.
1204     InstructionsProcessed->insert(NewChain.front());
1205     return false;
1206   }
1207 
1208   // Update Chain to the valid vectorizable subchain.
1209   Chain = NewChain;
1210   ChainSize = Chain.size();
1211 
1212   // Check if it's legal to vectorize this chain. If not, split the chain and
1213   // try again.
1214   unsigned EltSzInBytes = Sz / 8;
1215   unsigned SzInBytes = EltSzInBytes * ChainSize;
1216   VectorType *VecTy;
1217   auto *VecLoadTy = dyn_cast<FixedVectorType>(LoadTy);
1218   if (VecLoadTy)
1219     VecTy = FixedVectorType::get(LoadTy->getScalarType(),
1220                                  Chain.size() * VecLoadTy->getNumElements());
1221   else
1222     VecTy = FixedVectorType::get(LoadTy, Chain.size());
1223 
1224   // If it's more than the max vector size or the target has a better
1225   // vector factor, break it into two pieces.
1226   unsigned TargetVF = TTI.getLoadVectorFactor(VF, Sz, SzInBytes, VecTy);
1227   if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
1228     LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor."
1229                          " Creating two separate arrays.\n");
1230     bool Vectorized = false;
1231     Vectorized |=
1232         vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed);
1233     Vectorized |=
1234         vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed);
1235     return Vectorized;
1236   }
1237 
1238   // We won't try again to vectorize the elements of the chain, regardless of
1239   // whether we succeed below.
1240   InstructionsProcessed->insert(Chain.begin(), Chain.end());
1241 
1242   // If the load is going to be misaligned, don't vectorize it.
1243   unsigned RelativeSpeed;
1244   if (accessIsMisaligned(SzInBytes, AS, Alignment, RelativeSpeed)) {
1245     if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
1246       unsigned SpeedBefore;
1247       accessIsMisaligned(EltSzInBytes, AS, Alignment, SpeedBefore);
1248       if (SpeedBefore > RelativeSpeed)
1249         return false;
1250 
1251       auto Chains = splitOddVectorElts(Chain, Sz);
1252       bool Vectorized = false;
1253       Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed);
1254       Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed);
1255       return Vectorized;
1256     }
1257 
1258     Align NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(),
1259                                                 Align(StackAdjustedAlignment),
1260                                                 DL, L0, nullptr, &DT);
1261     if (NewAlign >= Alignment)
1262       Alignment = NewAlign;
1263     else
1264       return false;
1265   }
1266 
1267   if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) {
1268     auto Chains = splitOddVectorElts(Chain, Sz);
1269     bool Vectorized = false;
1270     Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed);
1271     Vectorized |= vectorizeLoadChain(Chains.second, InstructionsProcessed);
1272     return Vectorized;
1273   }
1274 
1275   LLVM_DEBUG({
1276     dbgs() << "LSV: Loads to vectorize:\n";
1277     for (Instruction *I : Chain)
1278       I->dump();
1279   });
1280 
1281   // getVectorizablePrefix already computed getBoundaryInstrs.  The value of
1282   // Last may have changed since then, but the value of First won't have.  If it
1283   // matters, we could compute getBoundaryInstrs only once and reuse it here.
1284   BasicBlock::iterator First, Last;
1285   std::tie(First, Last) = getBoundaryInstrs(Chain);
1286   Builder.SetInsertPoint(&*First);
1287 
1288   Value *Bitcast =
1289       Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS));
1290   LoadInst *LI =
1291       Builder.CreateAlignedLoad(VecTy, Bitcast, MaybeAlign(Alignment));
1292   propagateMetadata(LI, Chain);
1293 
1294   for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
1295     Value *CV = Chain[I];
1296     Value *V;
1297     if (VecLoadTy) {
1298       // Extract a subvector using shufflevector.
1299       unsigned VecWidth = VecLoadTy->getNumElements();
1300       auto Mask =
1301           llvm::to_vector<8>(llvm::seq<int>(I * VecWidth, (I + 1) * VecWidth));
1302       V = Builder.CreateShuffleVector(LI, Mask, CV->getName());
1303     } else {
1304       V = Builder.CreateExtractElement(LI, Builder.getInt32(I), CV->getName());
1305     }
1306 
1307     if (V->getType() != CV->getType()) {
1308       V = Builder.CreateBitOrPointerCast(V, CV->getType());
1309     }
1310 
1311     // Replace the old instruction.
1312     CV->replaceAllUsesWith(V);
1313   }
1314 
1315   // Since we might have opaque pointers we might end up using the pointer
1316   // operand of the first load (wrt. memory loaded) for the vector load. Since
1317   // this first load might not be the first in the block we potentially need to
1318   // reorder the pointer operand (and its operands). If we have a bitcast though
1319   // it might be before the load and should be the reorder start instruction.
1320   // "Might" because for opaque pointers the "bitcast" is just the first loads
1321   // pointer operand, as oppposed to something we inserted at the right position
1322   // ourselves.
1323   Instruction *BCInst = dyn_cast<Instruction>(Bitcast);
1324   reorder((BCInst && BCInst != L0->getPointerOperand()) ? BCInst : LI);
1325 
1326   eraseInstructions(Chain);
1327 
1328   ++NumVectorInstructions;
1329   NumScalarsVectorized += Chain.size();
1330   return true;
1331 }
1332 
1333 bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
1334                                     Align Alignment, unsigned &RelativeSpeed) {
1335   RelativeSpeed = 0;
1336   if (Alignment.value() % SzInBytes == 0)
1337     return false;
1338 
1339   bool Allows = TTI.allowsMisalignedMemoryAccesses(F.getParent()->getContext(),
1340                                                    SzInBytes * 8, AddressSpace,
1341                                                    Alignment, &RelativeSpeed);
1342   LLVM_DEBUG(dbgs() << "LSV: Target said misaligned is allowed? " << Allows
1343                     << " with relative speed = " << RelativeSpeed << '\n';);
1344   return !Allows || !RelativeSpeed;
1345 }
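// Illustrative check (assumed numbers): a 16-byte access known only to be
// 4-byte aligned has Alignment.value() % SzInBytes == 4, so the target is
// queried; the access counts as misaligned unless TTI allows it with a nonzero
// relative speed.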
1346