1 //===- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass merges loads/stores to/from sequential memory addresses into vector
10 // loads/stores. Although there's nothing GPU-specific in here, this pass is
11 // motivated by the microarchitectural quirks of NVIDIA and AMD GPUs.
12 //
13 // (For simplicity below we talk about loads only, but everything also applies
14 // to stores.)
15 //
16 // This pass is intended to be run late in the pipeline, after other
17 // vectorization opportunities have been exploited. So the assumption here is
18 // that immediately following our new vector load we'll need to extract out the
19 // individual elements of the load, so we can operate on them individually.
20 //
21 // On CPUs this transformation is usually not beneficial, because extracting the
22 // elements of a vector register is expensive on most architectures. It's
23 // usually better just to load each element individually into its own scalar
24 // register.
25 //
26 // However, NVIDIA and AMD GPUs don't have proper vector registers. Instead, a
27 // "vector load" loads directly into a series of scalar registers. In effect,
28 // extracting the elements of the vector is free. It's therefore always
29 // beneficial to vectorize a sequence of loads on these architectures.
30 //
31 // Vectorizing (perhaps a better name might be "coalescing") loads can have
32 // large performance impacts on GPU kernels, and opportunities for vectorizing
33 // are common in GPU code. This pass tries very hard to find such
34 // opportunities; its runtime is quadratic in the number of loads in a BB.
35 //
36 // Some CPU architectures, such as ARM, have instructions that load into
37 // multiple scalar registers, similar to a GPU vectorized load. In theory ARM
38 // could use this pass (with some modifications), but currently it implements
39 // its own pass to do something similar to what we do here.
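//
// As a rough sketch of the transformation (illustrative IR, not taken from a
// test): given two consecutive loads
//
//   %l0 = load i32, i32* %p, align 8
//   %p1 = getelementptr inbounds i32, i32* %p, i64 1
//   %l1 = load i32, i32* %p1, align 4
//
// the pass emits approximately
//
//   %vp = bitcast i32* %p to <2 x i32>*
//   %v  = load <2 x i32>, <2 x i32>* %vp, align 8
//   %l0 = extractelement <2 x i32> %v, i32 0
//   %l1 = extractelement <2 x i32> %v, i32 1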
40
41 #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
42 #include "llvm/ADT/APInt.h"
43 #include "llvm/ADT/ArrayRef.h"
44 #include "llvm/ADT/MapVector.h"
45 #include "llvm/ADT/PostOrderIterator.h"
46 #include "llvm/ADT/STLExtras.h"
47 #include "llvm/ADT/SmallPtrSet.h"
48 #include "llvm/ADT/SmallVector.h"
49 #include "llvm/ADT/Statistic.h"
50 #include "llvm/ADT/iterator_range.h"
51 #include "llvm/Analysis/AliasAnalysis.h"
52 #include "llvm/Analysis/AssumptionCache.h"
53 #include "llvm/Analysis/MemoryLocation.h"
54 #include "llvm/Analysis/ScalarEvolution.h"
55 #include "llvm/Analysis/TargetTransformInfo.h"
56 #include "llvm/Analysis/ValueTracking.h"
57 #include "llvm/Analysis/VectorUtils.h"
58 #include "llvm/IR/Attributes.h"
59 #include "llvm/IR/BasicBlock.h"
60 #include "llvm/IR/Constants.h"
61 #include "llvm/IR/DataLayout.h"
62 #include "llvm/IR/DerivedTypes.h"
63 #include "llvm/IR/Dominators.h"
64 #include "llvm/IR/Function.h"
65 #include "llvm/IR/IRBuilder.h"
66 #include "llvm/IR/InstrTypes.h"
67 #include "llvm/IR/Instruction.h"
68 #include "llvm/IR/Instructions.h"
69 #include "llvm/IR/IntrinsicInst.h"
70 #include "llvm/IR/Module.h"
71 #include "llvm/IR/Type.h"
72 #include "llvm/IR/User.h"
73 #include "llvm/IR/Value.h"
74 #include "llvm/InitializePasses.h"
75 #include "llvm/Pass.h"
76 #include "llvm/Support/Casting.h"
77 #include "llvm/Support/Debug.h"
78 #include "llvm/Support/KnownBits.h"
79 #include "llvm/Support/MathExtras.h"
80 #include "llvm/Support/raw_ostream.h"
81 #include "llvm/Transforms/Utils/Local.h"
82 #include "llvm/Transforms/Vectorize.h"
83 #include <algorithm>
84 #include <cassert>
85 #include <cstdlib>
86 #include <tuple>
87 #include <utility>
88
89 using namespace llvm;
90
91 #define DEBUG_TYPE "load-store-vectorizer"
92
93 STATISTIC(NumVectorInstructions, "Number of vector accesses generated");
94 STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized");
95
96 // FIXME: Assuming stack alignment of 4 is always good enough
97 static const unsigned StackAdjustedAlignment = 4;
98
99 namespace {
100
101 /// ChainID is an arbitrary token that is allowed to be different only for the
102 /// accesses that are guaranteed to be considered non-consecutive by
103 /// Vectorizer::isConsecutiveAccess. It's used for grouping instructions
104 /// together and reducing the number of instructions the main search operates on
105 /// at a time; this exists purely to reduce compile time, as the main
106 /// search has O(n^2) time complexity. The underlying type of ChainID should not
107 /// be relied upon.
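/// In the current implementation (see getChainID below), accesses whose
/// pointers share the same underlying object (or, for selects, the same
/// select condition) are given the same ChainID.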
108 using ChainID = const Value *;
109 using InstrList = SmallVector<Instruction *, 8>;
110 using InstrListMap = MapVector<ChainID, InstrList>;
111
112 class Vectorizer {
113 Function &F;
114 AliasAnalysis &AA;
115 AssumptionCache &AC;
116 DominatorTree &DT;
117 ScalarEvolution &SE;
118 TargetTransformInfo &TTI;
119 const DataLayout &DL;
120 IRBuilder<> Builder;
121
122 public:
123 Vectorizer(Function &F, AliasAnalysis &AA, AssumptionCache &AC,
124 DominatorTree &DT, ScalarEvolution &SE, TargetTransformInfo &TTI)
125 : F(F), AA(AA), AC(AC), DT(DT), SE(SE), TTI(TTI),
126 DL(F.getParent()->getDataLayout()), Builder(SE.getContext()) {}
127
128 bool run();
129
130 private:
131 unsigned getPointerAddressSpace(Value *I);
132
133 static const unsigned MaxDepth = 3;
134
135 bool isConsecutiveAccess(Value *A, Value *B);
136 bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt PtrDelta,
137 unsigned Depth = 0) const;
138 bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta,
139 unsigned Depth) const;
140 bool lookThroughSelects(Value *PtrA, Value *PtrB, const APInt &PtrDelta,
141 unsigned Depth) const;
142
143 /// After vectorization, reorder the instructions that I depends on
144 /// (the instructions defining its operands), to ensure they dominate I.
145 void reorder(Instruction *I);
146
147 /// Returns the first and the last instructions in Chain.
148 std::pair<BasicBlock::iterator, BasicBlock::iterator>
149 getBoundaryInstrs(ArrayRef<Instruction *> Chain);
150
151 /// Erases the original instructions after vectorizing.
152 void eraseInstructions(ArrayRef<Instruction *> Chain);
153
154 /// "Legalize" the vector type that would be produced by combining \p
155 /// ElementSizeBits elements in \p Chain. Break into two pieces such that the
156 /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is
157 /// expected to have more than 4 elements.
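/// For example (a sketch): a chain of 6 one-byte elements is split into
/// pieces of 4 and 2 elements, and a chain of 9 one-byte elements into
/// pieces of 8 and 1.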
158 std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>>
159 splitOddVectorElts(ArrayRef<Instruction *> Chain, unsigned ElementSizeBits);
160
161 /// Finds the largest prefix of Chain that's vectorizable, checking for
162 /// intervening instructions which may affect the memory accessed by the
163 /// instructions within Chain.
164 ///
165 /// The elements of \p Chain must be all loads or all stores and must be in
166 /// address order.
167 ArrayRef<Instruction *> getVectorizablePrefix(ArrayRef<Instruction *> Chain);
168
169 /// Collects load and store instructions to vectorize.
170 std::pair<InstrListMap, InstrListMap> collectInstructions(BasicBlock *BB);
171
172 /// Processes the collected instructions in \p Map. The values of \p Map
173 /// must be either all loads or all stores.
174 bool vectorizeChains(InstrListMap &Map);
175
176 /// Finds loads/stores to consecutive memory addresses and vectorizes them.
177 bool vectorizeInstructions(ArrayRef<Instruction *> Instrs);
178
179 /// Vectorizes the load instructions in Chain.
180 bool
181 vectorizeLoadChain(ArrayRef<Instruction *> Chain,
182 SmallPtrSet<Instruction *, 16> *InstructionsProcessed);
183
184 /// Vectorizes the store instructions in Chain.
185 bool
186 vectorizeStoreChain(ArrayRef<Instruction *> Chain,
187 SmallPtrSet<Instruction *, 16> *InstructionsProcessed);
188
189 /// Check if this load/store access is misaligned.
190 bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
191 Align Alignment);
192 };
193
194 class LoadStoreVectorizerLegacyPass : public FunctionPass {
195 public:
196 static char ID;
197
198 LoadStoreVectorizerLegacyPass() : FunctionPass(ID) {
199 initializeLoadStoreVectorizerLegacyPassPass(*PassRegistry::getPassRegistry());
200 }
201
202 bool runOnFunction(Function &F) override;
203
204 StringRef getPassName() const override {
205 return "GPU Load and Store Vectorizer";
206 }
207
208 void getAnalysisUsage(AnalysisUsage &AU) const override {
209 AU.addRequired<AAResultsWrapperPass>();
210 AU.addRequired<AssumptionCacheTracker>();
211 AU.addRequired<ScalarEvolutionWrapperPass>();
212 AU.addRequired<DominatorTreeWrapperPass>();
213 AU.addRequired<TargetTransformInfoWrapperPass>();
214 AU.setPreservesCFG();
215 }
216 };
217
218 } // end anonymous namespace
219
220 char LoadStoreVectorizerLegacyPass::ID = 0;
221
222 INITIALIZE_PASS_BEGIN(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
223 "Vectorize load and store instructions", false, false)
224 INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
225 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
226 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
227 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
228 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
229 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
230 INITIALIZE_PASS_END(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
231 "Vectorize load and store instructions", false, false)
232
233 Pass *llvm::createLoadStoreVectorizerPass() {
234 return new LoadStoreVectorizerLegacyPass();
235 }
236
237 bool LoadStoreVectorizerLegacyPass::runOnFunction(Function &F) {
238 // Don't vectorize when the attribute NoImplicitFloat is used.
239 if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat))
240 return false;
241
242 AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
243 DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
244 ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
245 TargetTransformInfo &TTI =
246 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
247
248 AssumptionCache &AC =
249 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
250
251 Vectorizer V(F, AA, AC, DT, SE, TTI);
252 return V.run();
253 }
254
255 PreservedAnalyses LoadStoreVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) {
256 // Don't vectorize when the attribute NoImplicitFloat is used.
257 if (F.hasFnAttribute(Attribute::NoImplicitFloat))
258 return PreservedAnalyses::all();
259
260 AliasAnalysis &AA = AM.getResult<AAManager>(F);
261 DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
262 ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
263 TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
264 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
265
266 Vectorizer V(F, AA, AC, DT, SE, TTI);
267 bool Changed = V.run();
268 PreservedAnalyses PA;
269 PA.preserveSet<CFGAnalyses>();
270 return Changed ? PA : PreservedAnalyses::all();
271 }
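// A minimal way to exercise this pass in isolation (assuming a standard LLVM
// build) is through the new pass manager, e.g.:
//
//   opt -passes=load-store-vectorizer -S input.ll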
272
273 // The real propagateMetadata expects a SmallVector<Value*>, but we deal in
274 // vectors of Instructions.
275 static void propagateMetadata(Instruction *I, ArrayRef<Instruction *> IL) {
276 SmallVector<Value *, 8> VL(IL.begin(), IL.end());
277 propagateMetadata(I, VL);
278 }
279
280 // Vectorizer Implementation
281 bool Vectorizer::run() {
282 bool Changed = false;
283
284 // Scan the blocks in the function in post order.
285 for (BasicBlock *BB : post_order(&F)) {
286 InstrListMap LoadRefs, StoreRefs;
287 std::tie(LoadRefs, StoreRefs) = collectInstructions(BB);
288 Changed |= vectorizeChains(LoadRefs);
289 Changed |= vectorizeChains(StoreRefs);
290 }
291
292 return Changed;
293 }
294
295 unsigned Vectorizer::getPointerAddressSpace(Value *I) {
296 if (LoadInst *L = dyn_cast<LoadInst>(I))
297 return L->getPointerAddressSpace();
298 if (StoreInst *S = dyn_cast<StoreInst>(I))
299 return S->getPointerAddressSpace();
300 return -1;
301 }
302
303 // FIXME: Merge with llvm::isConsecutiveAccess
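// For example (a sketch): an i32 load from %p and an i32 load from
// "getelementptr inbounds i32, i32* %p, i64 1" are consecutive, because the
// second address is exactly 4 bytes (the store size of i32) past the first.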
304 bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
305 Value *PtrA = getLoadStorePointerOperand(A);
306 Value *PtrB = getLoadStorePointerOperand(B);
307 unsigned ASA = getPointerAddressSpace(A);
308 unsigned ASB = getPointerAddressSpace(B);
309
310 // Check that the address spaces match and that the pointers are valid.
311 if (!PtrA || !PtrB || (ASA != ASB))
312 return false;
313
314 // Make sure that A and B are different pointers of the same size type.
315 Type *PtrATy = PtrA->getType()->getPointerElementType();
316 Type *PtrBTy = PtrB->getType()->getPointerElementType();
317 if (PtrA == PtrB ||
318 PtrATy->isVectorTy() != PtrBTy->isVectorTy() ||
319 DL.getTypeStoreSize(PtrATy) != DL.getTypeStoreSize(PtrBTy) ||
320 DL.getTypeStoreSize(PtrATy->getScalarType()) !=
321 DL.getTypeStoreSize(PtrBTy->getScalarType()))
322 return false;
323
324 unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA);
325 APInt Size(PtrBitWidth, DL.getTypeStoreSize(PtrATy));
326
327 return areConsecutivePointers(PtrA, PtrB, Size);
328 }
329
330 bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB,
331 APInt PtrDelta, unsigned Depth) const {
332 unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType());
333 APInt OffsetA(PtrBitWidth, 0);
334 APInt OffsetB(PtrBitWidth, 0);
335 PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
336 PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
337
338 unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType());
339
340 if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType()))
341 return false;
342
343 // If we had to shrink the pointer,
344 // stripAndAccumulateInBoundsConstantOffsets should have properly handled any
345 // possible overflow, and the value should fit into the smallest data type
346 // used in the cast/gep chain.
347 assert(OffsetA.getMinSignedBits() <= NewPtrBitWidth &&
348 OffsetB.getMinSignedBits() <= NewPtrBitWidth);
349
350 OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
351 OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
352 PtrDelta = PtrDelta.sextOrTrunc(NewPtrBitWidth);
353
354 APInt OffsetDelta = OffsetB - OffsetA;
355
356 // Check if they are based on the same pointer. That makes the offsets
357 // sufficient.
358 if (PtrA == PtrB)
359 return OffsetDelta == PtrDelta;
360
361 // Compute the base pointer delta needed to make the final delta
362 // equal to the requested pointer delta.
363 APInt BaseDelta = PtrDelta - OffsetDelta;
364
365 // Compute the distance with SCEV between the base pointers.
366 const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
367 const SCEV *PtrSCEVB = SE.getSCEV(PtrB);
368 const SCEV *C = SE.getConstant(BaseDelta);
369 const SCEV *X = SE.getAddExpr(PtrSCEVA, C);
370 if (X == PtrSCEVB)
371 return true;
372
373 // The above check will not catch cases where one of the pointers is
374 // factorized but the other one is not, such as (C + (S * (A + B))) vs
375 // (AS + BS). Compute the difference with getMinusSCEV, which re-combines
376 // the expressions and yields the simplified difference.
377 const SCEV *Dist = SE.getMinusSCEV(PtrSCEVB, PtrSCEVA);
378 if (C == Dist)
379 return true;
380
381 // Sometimes even this doesn't work, because SCEV can't always see through
382 // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
383 // things the hard way.
384 return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
385 }
386
387 bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
388 APInt PtrDelta,
389 unsigned Depth) const {
390 auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
391 auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
392 if (!GEPA || !GEPB)
393 return lookThroughSelects(PtrA, PtrB, PtrDelta, Depth);
394
395 // Look through GEPs after checking they're the same except for the last
396 // index.
397 if (GEPA->getNumOperands() != GEPB->getNumOperands() ||
398 GEPA->getPointerOperand() != GEPB->getPointerOperand())
399 return false;
400 gep_type_iterator GTIA = gep_type_begin(GEPA);
401 gep_type_iterator GTIB = gep_type_begin(GEPB);
402 for (unsigned I = 0, E = GEPA->getNumIndices() - 1; I < E; ++I) {
403 if (GTIA.getOperand() != GTIB.getOperand())
404 return false;
405 ++GTIA;
406 ++GTIB;
407 }
408
409 Instruction *OpA = dyn_cast<Instruction>(GTIA.getOperand());
410 Instruction *OpB = dyn_cast<Instruction>(GTIB.getOperand());
411 if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
412 OpA->getType() != OpB->getType())
413 return false;
414
415 if (PtrDelta.isNegative()) {
416 if (PtrDelta.isMinSignedValue())
417 return false;
418 PtrDelta.negate();
419 std::swap(OpA, OpB);
420 }
421 uint64_t Stride = DL.getTypeAllocSize(GTIA.getIndexedType());
422 if (PtrDelta.urem(Stride) != 0)
423 return false;
424 unsigned IdxBitWidth = OpA->getType()->getScalarSizeInBits();
425 APInt IdxDiff = PtrDelta.udiv(Stride).zextOrSelf(IdxBitWidth);
426
427 // Only look through a ZExt/SExt.
428 if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
429 return false;
430
431 bool Signed = isa<SExtInst>(OpA);
432
433 // At this point A could be a function parameter, i.e. not an instruction
434 Value *ValA = OpA->getOperand(0);
435 OpB = dyn_cast<Instruction>(OpB->getOperand(0));
436 if (!OpB || ValA->getType() != OpB->getType())
437 return false;
438
439 // Now we need to prove that adding IdxDiff to ValA won't overflow.
440 bool Safe = false;
441 auto CheckFlags = [](Instruction *I, bool Signed) {
442 BinaryOperator *BinOpI = cast<BinaryOperator>(I);
443 return (Signed && BinOpI->hasNoSignedWrap()) ||
444 (!Signed && BinOpI->hasNoUnsignedWrap());
445 };
446
447 // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
448 // ValA, we're okay.
449 if (OpB->getOpcode() == Instruction::Add &&
450 isa<ConstantInt>(OpB->getOperand(1)) &&
451 IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
452 CheckFlags(OpB, Signed))
453 Safe = true;
454
455 // Second attempt: If both OpA and OpB are adds with NSW/NUW and with
456 // the same LHS operand, we can guarantee that the transformation is safe
457 // if we can prove that OpA won't overflow when IdxDiff is added to the RHS
458 // of OpA.
459 // For example:
460 // %tmp7 = add nsw i32 %tmp2, %v0
461 // %tmp8 = sext i32 %tmp7 to i64
462 // ...
463 // %tmp11 = add nsw i32 %v0, 1
464 // %tmp12 = add nsw i32 %tmp2, %tmp11
465 // %tmp13 = sext i32 %tmp12 to i64
466 //
467 // Both %tmp7 and %tmp12 have the nsw flag and their first operand
468 // is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
469 // because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 have the
470 // nsw flag.
471 OpA = dyn_cast<Instruction>(ValA);
472 if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
473 OpB->getOpcode() == Instruction::Add &&
474 OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
475 CheckFlags(OpB, Signed)) {
476 Value *RHSA = OpA->getOperand(1);
477 Value *RHSB = OpB->getOperand(1);
478 Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
479 Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
480 // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
481 if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
482 CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
483 int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
484 if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
485 Safe = true;
486 }
487 // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
488 if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
489 CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
490 int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
491 if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
492 Safe = true;
493 }
494 // Match `x +nsw/nuw (y +nsw/nuw c)` and
495 // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
496 if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
497 OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
498 CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
499 isa<ConstantInt>(OpRHSB->getOperand(1))) {
500 int64_t CstValA =
501 cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
502 int64_t CstValB =
503 cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
504 if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
505 IdxDiff.getSExtValue() == (CstValB - CstValA))
506 Safe = true;
507 }
508 }
509
510 unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
511
512 // Third attempt:
513 // If all the bits set in IdxDiff, and all higher-order bits other than the
514 // sign bit, are known to be zero in ValA, we can add IdxDiff to ValA while
515 // guaranteeing no overflow of any sort.
516 if (!Safe) {
517 KnownBits Known(BitWidth);
518 computeKnownBits(ValA, Known, DL, 0, &AC, OpB, &DT);
519 APInt BitsAllowedToBeSet = Known.Zero.zext(IdxDiff.getBitWidth());
520 if (Signed)
521 BitsAllowedToBeSet.clearBit(BitWidth - 1);
522 if (BitsAllowedToBeSet.ult(IdxDiff))
523 return false;
524 }
525
526 const SCEV *OffsetSCEVA = SE.getSCEV(ValA);
527 const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
528 const SCEV *C = SE.getConstant(IdxDiff.trunc(BitWidth));
529 const SCEV *X = SE.getAddExpr(OffsetSCEVA, C);
530 return X == OffsetSCEVB;
531 }
532
533 bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB,
534 const APInt &PtrDelta,
535 unsigned Depth) const {
536 if (Depth++ == MaxDepth)
537 return false;
538
539 if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
540 if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
541 return SelectA->getCondition() == SelectB->getCondition() &&
542 areConsecutivePointers(SelectA->getTrueValue(),
543 SelectB->getTrueValue(), PtrDelta, Depth) &&
544 areConsecutivePointers(SelectA->getFalseValue(),
545 SelectB->getFalseValue(), PtrDelta, Depth);
546 }
547 }
548 return false;
549 }
550
551 void Vectorizer::reorder(Instruction *I) {
552 SmallPtrSet<Instruction *, 16> InstructionsToMove;
553 SmallVector<Instruction *, 16> Worklist;
554
555 Worklist.push_back(I);
556 while (!Worklist.empty()) {
557 Instruction *IW = Worklist.pop_back_val();
558 int NumOperands = IW->getNumOperands();
559 for (int i = 0; i < NumOperands; i++) {
560 Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i));
561 if (!IM || IM->getOpcode() == Instruction::PHI)
562 continue;
563
564 // If IM is in another BB, no need to move it, because this pass only
565 // vectorizes instructions within one BB.
566 if (IM->getParent() != I->getParent())
567 continue;
568
569 if (!IM->comesBefore(I)) {
570 InstructionsToMove.insert(IM);
571 Worklist.push_back(IM);
572 }
573 }
574 }
575
576 // All instructions to move should follow I. Start from I, not from begin().
577 for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;
578 ++BBI) {
579 if (!InstructionsToMove.count(&*BBI))
580 continue;
581 Instruction *IM = &*BBI;
582 --BBI;
583 IM->removeFromParent();
584 IM->insertBefore(I);
585 }
586 }
587
588 std::pair<BasicBlock::iterator, BasicBlock::iterator>
589 Vectorizer::getBoundaryInstrs(ArrayRef<Instruction *> Chain) {
590 Instruction *C0 = Chain[0];
591 BasicBlock::iterator FirstInstr = C0->getIterator();
592 BasicBlock::iterator LastInstr = C0->getIterator();
593
594 BasicBlock *BB = C0->getParent();
595 unsigned NumFound = 0;
596 for (Instruction &I : *BB) {
597 if (!is_contained(Chain, &I))
598 continue;
599
600 ++NumFound;
601 if (NumFound == 1) {
602 FirstInstr = I.getIterator();
603 }
604 if (NumFound == Chain.size()) {
605 LastInstr = I.getIterator();
606 break;
607 }
608 }
609
610 // Range is [first, last).
611 return std::make_pair(FirstInstr, ++LastInstr);
612 }
613
614 void Vectorizer::eraseInstructions(ArrayRef<Instruction *> Chain) {
615 SmallVector<Instruction *, 16> Instrs;
616 for (Instruction *I : Chain) {
617 Value *PtrOperand = getLoadStorePointerOperand(I);
618 assert(PtrOperand && "Instruction must have a pointer operand.");
619 Instrs.push_back(I);
620 if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand))
621 Instrs.push_back(GEP);
622 }
623
624 // Erase instructions.
625 for (Instruction *I : Instrs)
626 if (I->use_empty())
627 I->eraseFromParent();
628 }
629
630 std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>>
631 Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain,
632 unsigned ElementSizeBits) {
633 unsigned ElementSizeBytes = ElementSizeBits / 8;
634 unsigned SizeBytes = ElementSizeBytes * Chain.size();
635 unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes;
636 if (NumLeft == Chain.size()) {
637 if ((NumLeft & 1) == 0)
638 NumLeft /= 2; // Split even in half
639 else
640 --NumLeft; // Split off last element
641 } else if (NumLeft == 0)
642 NumLeft = 1;
643 return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft));
644 }
645
646 ArrayRef<Instruction *>
647 Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
648 // These are in BB order, unlike Chain, which is in address order.
649 SmallVector<Instruction *, 16> MemoryInstrs;
650 SmallVector<Instruction *, 16> ChainInstrs;
651
652 bool IsLoadChain = isa<LoadInst>(Chain[0]);
653 LLVM_DEBUG({
654 for (Instruction *I : Chain) {
655 if (IsLoadChain)
656 assert(isa<LoadInst>(I) &&
657 "All elements of Chain must be loads, or all must be stores.");
658 else
659 assert(isa<StoreInst>(I) &&
660 "All elements of Chain must be loads, or all must be stores.");
661 }
662 });
663
664 for (Instruction &I : make_range(getBoundaryInstrs(Chain))) {
665 if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
666 if (!is_contained(Chain, &I))
667 MemoryInstrs.push_back(&I);
668 else
669 ChainInstrs.push_back(&I);
670 } else if (isa<IntrinsicInst>(&I) &&
671 cast<IntrinsicInst>(&I)->getIntrinsicID() ==
672 Intrinsic::sideeffect) {
673 // Ignore llvm.sideeffect calls.
674 } else if (isa<IntrinsicInst>(&I) &&
675 cast<IntrinsicInst>(&I)->getIntrinsicID() ==
676 Intrinsic::pseudoprobe) {
677 // Ignore llvm.pseudoprobe calls.
678 } else if (isa<IntrinsicInst>(&I) &&
679 cast<IntrinsicInst>(&I)->getIntrinsicID() == Intrinsic::assume) {
680 // Ignore llvm.assume calls.
681 } else if (IsLoadChain && (I.mayWriteToMemory() || I.mayThrow())) {
682 LLVM_DEBUG(dbgs() << "LSV: Found may-write/throw operation: " << I
683 << '\n');
684 break;
685 } else if (!IsLoadChain && (I.mayReadOrWriteMemory() || I.mayThrow())) {
686 LLVM_DEBUG(dbgs() << "LSV: Found may-read/write/throw operation: " << I
687 << '\n');
688 break;
689 }
690 }
691
692 // Loop until we find an instruction in ChainInstrs that we can't vectorize.
693 unsigned ChainInstrIdx = 0;
694 Instruction *BarrierMemoryInstr = nullptr;
695
696 for (unsigned E = ChainInstrs.size(); ChainInstrIdx < E; ++ChainInstrIdx) {
697 Instruction *ChainInstr = ChainInstrs[ChainInstrIdx];
698
699 // If a barrier memory instruction was found, chain instructions that follow
700 // will not be added to the valid prefix.
701 if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(ChainInstr))
702 break;
703
704 // Check (in BB order) if any instruction prevents ChainInstr from being
705 // vectorized. Find and store the first such "conflicting" instruction.
706 for (Instruction *MemInstr : MemoryInstrs) {
707 // If a barrier memory instruction was found, do not check past it.
708 if (BarrierMemoryInstr && BarrierMemoryInstr->comesBefore(MemInstr))
709 break;
710
711 auto *MemLoad = dyn_cast<LoadInst>(MemInstr);
712 auto *ChainLoad = dyn_cast<LoadInst>(ChainInstr);
713 if (MemLoad && ChainLoad)
714 continue;
715
716 // We can ignore the alias if we have a load/store pair and the load
717 // is known to be invariant. The load cannot be clobbered by the store.
718 auto IsInvariantLoad = [](const LoadInst *LI) -> bool {
719 return LI->hasMetadata(LLVMContext::MD_invariant_load);
720 };
721
722 // We can ignore the alias as long as the load comes before the store,
723 // because that means we won't be moving the load past the store to
724 // vectorize it (the vectorized load is inserted at the location of the
725 // first load in the chain).
726 if (isa<StoreInst>(MemInstr) && ChainLoad &&
727 (IsInvariantLoad(ChainLoad) || ChainLoad->comesBefore(MemInstr)))
728 continue;
729
730 // Same case, but in reverse.
731 if (MemLoad && isa<StoreInst>(ChainInstr) &&
732 (IsInvariantLoad(MemLoad) || MemLoad->comesBefore(ChainInstr)))
733 continue;
734
735 if (!AA.isNoAlias(MemoryLocation::get(MemInstr),
736 MemoryLocation::get(ChainInstr))) {
737 LLVM_DEBUG({
738 dbgs() << "LSV: Found alias:\n"
739 " Aliasing instruction and pointer:\n"
740 << " " << *MemInstr << '\n'
741 << " " << *getLoadStorePointerOperand(MemInstr) << '\n'
742 << " Aliased instruction and pointer:\n"
743 << " " << *ChainInstr << '\n'
744 << " " << *getLoadStorePointerOperand(ChainInstr) << '\n';
745 });
746 // Save this aliasing memory instruction as a barrier, but allow other
747 // instructions that precede the barrier to be vectorized with this one.
748 BarrierMemoryInstr = MemInstr;
749 break;
750 }
751 }
752 // Continue the search only for store chains, since vectorizing stores that
753 // precede an aliasing load is valid. Conversely, vectorizing loads is valid
754 // up to an aliasing store, but should not pull loads from further down in
755 // the basic block.
756 if (IsLoadChain && BarrierMemoryInstr) {
757 // The BarrierMemoryInstr is a store that precedes ChainInstr.
758 assert(BarrierMemoryInstr->comesBefore(ChainInstr));
759 break;
760 }
761 }
762
763 // Find the largest prefix of Chain whose elements are all in
764 // ChainInstrs[0, ChainInstrIdx). This is the largest vectorizable prefix of
765 // Chain. (Recall that Chain is in address order, but ChainInstrs is in BB
766 // order.)
767 SmallPtrSet<Instruction *, 8> VectorizableChainInstrs(
768 ChainInstrs.begin(), ChainInstrs.begin() + ChainInstrIdx);
769 unsigned ChainIdx = 0;
770 for (unsigned ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) {
771 if (!VectorizableChainInstrs.count(Chain[ChainIdx]))
772 break;
773 }
774 return Chain.slice(0, ChainIdx);
775 }
776
777 static ChainID getChainID(const Value *Ptr) {
778 const Value *ObjPtr = getUnderlyingObject(Ptr);
779 if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
780 // The selects themselves are distinct instructions even if they share the
781 // same condition and evaluate to consecutive pointers for the true and false
782 // values of the condition. Therefore, using the selects themselves for
783 // grouping instructions would put consecutive accesses into different lists,
784 // so they wouldn't even be checked for being consecutive, and wouldn't be
785 // vectorized.
786 return Sel->getCondition();
787 }
788 return ObjPtr;
789 }
790
791 std::pair<InstrListMap, InstrListMap>
792 Vectorizer::collectInstructions(BasicBlock *BB) {
793 InstrListMap LoadRefs;
794 InstrListMap StoreRefs;
795
796 for (Instruction &I : *BB) {
797 if (!I.mayReadOrWriteMemory())
798 continue;
799
800 if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
801 if (!LI->isSimple())
802 continue;
803
804 // Skip if it's not legal.
805 if (!TTI.isLegalToVectorizeLoad(LI))
806 continue;
807
808 Type *Ty = LI->getType();
809 if (!VectorType::isValidElementType(Ty->getScalarType()))
810 continue;
811
812 // Skip weird non-byte sizes. They probably aren't worth the effort of
813 // handling correctly.
814 unsigned TySize = DL.getTypeSizeInBits(Ty);
815 if ((TySize % 8) != 0)
816 continue;
817
818 // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
819 // functions are currently using an integer type for the vectorized
820 // load/store, and do not support casting between the integer type and a
821 // vector of pointers (e.g. i64 to <2 x i16*>)
822 if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
823 continue;
824
825 Value *Ptr = LI->getPointerOperand();
826 unsigned AS = Ptr->getType()->getPointerAddressSpace();
827 unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
828
829 unsigned VF = VecRegSize / TySize;
830 VectorType *VecTy = dyn_cast<VectorType>(Ty);
831
832 // No point in looking at these if they're too big to vectorize.
833 if (TySize > VecRegSize / 2 ||
834 (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
835 continue;
836
837 // Make sure all the users of a vector are constant-index extracts.
838 if (isa<VectorType>(Ty) && !llvm::all_of(LI->users(), [](const User *U) {
839 const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U);
840 return EEI && isa<ConstantInt>(EEI->getOperand(1));
841 }))
842 continue;
843
844 // Save the load locations.
845 const ChainID ID = getChainID(Ptr);
846 LoadRefs[ID].push_back(LI);
847 } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
848 if (!SI->isSimple())
849 continue;
850
851 // Skip if it's not legal.
852 if (!TTI.isLegalToVectorizeStore(SI))
853 continue;
854
855 Type *Ty = SI->getValueOperand()->getType();
856 if (!VectorType::isValidElementType(Ty->getScalarType()))
857 continue;
858
859 // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
860 // functions are currently using an integer type for the vectorized
861 // load/store, and do not support casting between the integer type and a
862 // vector of pointers (e.g. i64 to <2 x i16*>)
863 if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
864 continue;
865
866 // Skip weird non-byte sizes. They probably aren't worth the effort of
867 // handling correctly.
868 unsigned TySize = DL.getTypeSizeInBits(Ty);
869 if ((TySize % 8) != 0)
870 continue;
871
872 Value *Ptr = SI->getPointerOperand();
873 unsigned AS = Ptr->getType()->getPointerAddressSpace();
874 unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
875
876 unsigned VF = VecRegSize / TySize;
877 VectorType *VecTy = dyn_cast<VectorType>(Ty);
878
879 // No point in looking at these if they're too big to vectorize.
880 if (TySize > VecRegSize / 2 ||
881 (VecTy && TTI.getStoreVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
882 continue;
883
884 if (isa<VectorType>(Ty) && !llvm::all_of(SI->users(), [](const User *U) {
885 const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U);
886 return EEI && isa<ConstantInt>(EEI->getOperand(1));
887 }))
888 continue;
889
890 // Save store location.
891 const ChainID ID = getChainID(Ptr);
892 StoreRefs[ID].push_back(SI);
893 }
894 }
895
896 return {LoadRefs, StoreRefs};
897 }
898
899 bool Vectorizer::vectorizeChains(InstrListMap &Map) {
900 bool Changed = false;
901
902 for (const std::pair<ChainID, InstrList> &Chain : Map) {
903 unsigned Size = Chain.second.size();
904 if (Size < 2)
905 continue;
906
907 LLVM_DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n");
908
909 // Process the instructions in chunks of 64.
910 for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) {
911 unsigned Len = std::min<unsigned>(CE - CI, 64);
912 ArrayRef<Instruction *> Chunk(&Chain.second[CI], Len);
913 Changed |= vectorizeInstructions(Chunk);
914 }
915 }
916
917 return Changed;
918 }
919
920 bool Vectorizer::vectorizeInstructions(ArrayRef<Instruction *> Instrs) {
921 LLVM_DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size()
922 << " instructions.\n");
923 SmallVector<int, 16> Heads, Tails;
924 int ConsecutiveChain[64];
925
926 // Do a quadratic search on all of the given loads/stores and find all of the
927 // pairs of loads/stores that follow each other.
928 for (int i = 0, e = Instrs.size(); i < e; ++i) {
929 ConsecutiveChain[i] = -1;
930 for (int j = e - 1; j >= 0; --j) {
931 if (i == j)
932 continue;
933
934 if (isConsecutiveAccess(Instrs[i], Instrs[j])) {
935 if (ConsecutiveChain[i] != -1) {
936 int CurDistance = std::abs(ConsecutiveChain[i] - i);
937 int NewDistance = std::abs(ConsecutiveChain[i] - j);
938 if (j < i || NewDistance > CurDistance)
939 continue; // Should not insert.
940 }
941
942 Tails.push_back(j);
943 Heads.push_back(i);
944 ConsecutiveChain[i] = j;
945 }
946 }
947 }
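  // A sketch of the state here, assuming Instrs holds accesses to p[0], p[1],
  // p[2] at indices 0, 1, 2: ConsecutiveChain = {1, 2, -1}, Heads = {0, 1},
  // Tails = {1, 2}. Below, only the chain starting at index 0 is followed;
  // head 1 is skipped because it is also the tail of that longer chain.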
948
949 bool Changed = false;
950 SmallPtrSet<Instruction *, 16> InstructionsProcessed;
951
952 for (int Head : Heads) {
953 if (InstructionsProcessed.count(Instrs[Head]))
954 continue;
955 bool LongerChainExists = false;
956 for (unsigned TIt = 0; TIt < Tails.size(); TIt++)
957 if (Head == Tails[TIt] &&
958 !InstructionsProcessed.count(Instrs[Heads[TIt]])) {
959 LongerChainExists = true;
960 break;
961 }
962 if (LongerChainExists)
963 continue;
964
965 // We found an instr that starts a chain. Now follow the chain and try to
966 // vectorize it.
967 SmallVector<Instruction *, 16> Operands;
968 int I = Head;
969 while (I != -1 && (is_contained(Tails, I) || is_contained(Heads, I))) {
970 if (InstructionsProcessed.count(Instrs[I]))
971 break;
972
973 Operands.push_back(Instrs[I]);
974 I = ConsecutiveChain[I];
975 }
976
977 bool Vectorized = false;
978 if (isa<LoadInst>(*Operands.begin()))
979 Vectorized = vectorizeLoadChain(Operands, &InstructionsProcessed);
980 else
981 Vectorized = vectorizeStoreChain(Operands, &InstructionsProcessed);
982
983 Changed |= Vectorized;
984 }
985
986 return Changed;
987 }
988
989 bool Vectorizer::vectorizeStoreChain(
990 ArrayRef<Instruction *> Chain,
991 SmallPtrSet<Instruction *, 16> *InstructionsProcessed) {
992 StoreInst *S0 = cast<StoreInst>(Chain[0]);
993
994 // If the vector has an int element, default to int for the whole store.
995 Type *StoreTy = nullptr;
996 for (Instruction *I : Chain) {
997 StoreTy = cast<StoreInst>(I)->getValueOperand()->getType();
998 if (StoreTy->isIntOrIntVectorTy())
999 break;
1000
1001 if (StoreTy->isPtrOrPtrVectorTy()) {
1002 StoreTy = Type::getIntNTy(F.getParent()->getContext(),
1003 DL.getTypeSizeInBits(StoreTy));
1004 break;
1005 }
1006 }
1007 assert(StoreTy && "Failed to find store type");
1008
1009 unsigned Sz = DL.getTypeSizeInBits(StoreTy);
1010 unsigned AS = S0->getPointerAddressSpace();
1011 unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
1012 unsigned VF = VecRegSize / Sz;
1013 unsigned ChainSize = Chain.size();
1014 Align Alignment = S0->getAlign();
1015
1016 if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
1017 InstructionsProcessed->insert(Chain.begin(), Chain.end());
1018 return false;
1019 }
1020
1021 ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain);
1022 if (NewChain.empty()) {
1023 // No vectorization possible.
1024 InstructionsProcessed->insert(Chain.begin(), Chain.end());
1025 return false;
1026 }
1027 if (NewChain.size() == 1) {
1028 // Failed after the first instruction. Discard it and try the smaller chain.
1029 InstructionsProcessed->insert(NewChain.front());
1030 return false;
1031 }
1032
1033 // Update Chain to the valid vectorizable subchain.
1034 Chain = NewChain;
1035 ChainSize = Chain.size();
1036
1037 // Check if it's legal to vectorize this chain. If not, split the chain and
1038 // try again.
1039 unsigned EltSzInBytes = Sz / 8;
1040 unsigned SzInBytes = EltSzInBytes * ChainSize;
1041
1042 FixedVectorType *VecTy;
1043 auto *VecStoreTy = dyn_cast<FixedVectorType>(StoreTy);
1044 if (VecStoreTy)
1045 VecTy = FixedVectorType::get(StoreTy->getScalarType(),
1046 Chain.size() * VecStoreTy->getNumElements());
1047 else
1048 VecTy = FixedVectorType::get(StoreTy, Chain.size());
1049
1050 // If it's more than the max vector size or the target has a better
1051 // vector factor, break it into two pieces.
1052 unsigned TargetVF = TTI.getStoreVectorFactor(VF, Sz, SzInBytes, VecTy);
1053 if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
1054 LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor."
1055 " Creating two separate arrays.\n");
1056 return vectorizeStoreChain(Chain.slice(0, TargetVF),
1057 InstructionsProcessed) |
1058 vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed);
1059 }
1060
1061 LLVM_DEBUG({
1062 dbgs() << "LSV: Stores to vectorize:\n";
1063 for (Instruction *I : Chain)
1064 dbgs() << " " << *I << "\n";
1065 });
1066
1067 // We won't try again to vectorize the elements of the chain, regardless of
1068 // whether we succeed below.
1069 InstructionsProcessed->insert(Chain.begin(), Chain.end());
1070
1071 // If the store is going to be misaligned, don't vectorize it.
1072 if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
1073 if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
1074 auto Chains = splitOddVectorElts(Chain, Sz);
1075 return vectorizeStoreChain(Chains.first, InstructionsProcessed) |
1076 vectorizeStoreChain(Chains.second, InstructionsProcessed);
1077 }
1078
1079 Align NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(),
1080 Align(StackAdjustedAlignment),
1081 DL, S0, nullptr, &DT);
1082 if (NewAlign >= Alignment)
1083 Alignment = NewAlign;
1084 else
1085 return false;
1086 }
1087
1088 if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) {
1089 auto Chains = splitOddVectorElts(Chain, Sz);
1090 return vectorizeStoreChain(Chains.first, InstructionsProcessed) |
1091 vectorizeStoreChain(Chains.second, InstructionsProcessed);
1092 }
1093
1094 BasicBlock::iterator First, Last;
1095 std::tie(First, Last) = getBoundaryInstrs(Chain);
1096 Builder.SetInsertPoint(&*Last);
1097
1098 Value *Vec = UndefValue::get(VecTy);
1099
1100 if (VecStoreTy) {
1101 unsigned VecWidth = VecStoreTy->getNumElements();
1102 for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
1103 StoreInst *Store = cast<StoreInst>(Chain[I]);
1104 for (unsigned J = 0, NE = VecStoreTy->getNumElements(); J != NE; ++J) {
1105 unsigned NewIdx = J + I * VecWidth;
1106 Value *Extract = Builder.CreateExtractElement(Store->getValueOperand(),
1107 Builder.getInt32(J));
1108 if (Extract->getType() != StoreTy->getScalarType())
1109 Extract = Builder.CreateBitCast(Extract, StoreTy->getScalarType());
1110
1111 Value *Insert =
1112 Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(NewIdx));
1113 Vec = Insert;
1114 }
1115 }
1116 } else {
1117 for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
1118 StoreInst *Store = cast<StoreInst>(Chain[I]);
1119 Value *Extract = Store->getValueOperand();
1120 if (Extract->getType() != StoreTy->getScalarType())
1121 Extract =
1122 Builder.CreateBitOrPointerCast(Extract, StoreTy->getScalarType());
1123
1124 Value *Insert =
1125 Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(I));
1126 Vec = Insert;
1127 }
1128 }
1129
1130 StoreInst *SI = Builder.CreateAlignedStore(
1131 Vec,
1132 Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS)),
1133 Alignment);
1134 propagateMetadata(SI, Chain);
1135
1136 eraseInstructions(Chain);
1137 ++NumVectorInstructions;
1138 NumScalarsVectorized += Chain.size();
1139 return true;
1140 }
1141
1142 bool Vectorizer::vectorizeLoadChain(
1143 ArrayRef<Instruction *> Chain,
1144 SmallPtrSet<Instruction *, 16> *InstructionsProcessed) {
1145 LoadInst *L0 = cast<LoadInst>(Chain[0]);
1146
1147 // If the vector has an int element, default to int for the whole load.
1148 Type *LoadTy = nullptr;
1149 for (const auto &V : Chain) {
1150 LoadTy = cast<LoadInst>(V)->getType();
1151 if (LoadTy->isIntOrIntVectorTy())
1152 break;
1153
1154 if (LoadTy->isPtrOrPtrVectorTy()) {
1155 LoadTy = Type::getIntNTy(F.getParent()->getContext(),
1156 DL.getTypeSizeInBits(LoadTy));
1157 break;
1158 }
1159 }
1160 assert(LoadTy && "Can't determine LoadInst type from chain");
1161
1162 unsigned Sz = DL.getTypeSizeInBits(LoadTy);
1163 unsigned AS = L0->getPointerAddressSpace();
1164 unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
1165 unsigned VF = VecRegSize / Sz;
1166 unsigned ChainSize = Chain.size();
1167 Align Alignment = L0->getAlign();
1168
1169 if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) {
1170 InstructionsProcessed->insert(Chain.begin(), Chain.end());
1171 return false;
1172 }
1173
1174 ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain);
1175 if (NewChain.empty()) {
1176 // No vectorization possible.
1177 InstructionsProcessed->insert(Chain.begin(), Chain.end());
1178 return false;
1179 }
1180 if (NewChain.size() == 1) {
1181 // Failed after the first instruction. Discard it and try the smaller chain.
1182 InstructionsProcessed->insert(NewChain.front());
1183 return false;
1184 }
1185
1186 // Update Chain to the valid vectorizable subchain.
1187 Chain = NewChain;
1188 ChainSize = Chain.size();
1189
1190 // Check if it's legal to vectorize this chain. If not, split the chain and
1191 // try again.
1192 unsigned EltSzInBytes = Sz / 8;
1193 unsigned SzInBytes = EltSzInBytes * ChainSize;
1194 VectorType *VecTy;
1195 auto *VecLoadTy = dyn_cast<FixedVectorType>(LoadTy);
1196 if (VecLoadTy)
1197 VecTy = FixedVectorType::get(LoadTy->getScalarType(),
1198 Chain.size() * VecLoadTy->getNumElements());
1199 else
1200 VecTy = FixedVectorType::get(LoadTy, Chain.size());
1201
1202 // If it's more than the max vector size or the target has a better
1203 // vector factor, break it into two pieces.
1204 unsigned TargetVF = TTI.getLoadVectorFactor(VF, Sz, SzInBytes, VecTy);
1205 if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) {
1206 LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor."
1207 " Creating two separate arrays.\n");
1208 return vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed) |
1209 vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed);
1210 }
1211
1212 // We won't try again to vectorize the elements of the chain, regardless of
1213 // whether we succeed below.
1214 InstructionsProcessed->insert(Chain.begin(), Chain.end());
1215
1216 // If the load is going to be misaligned, don't vectorize it.
1217 if (accessIsMisaligned(SzInBytes, AS, Alignment)) {
1218 if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) {
1219 auto Chains = splitOddVectorElts(Chain, Sz);
1220 return vectorizeLoadChain(Chains.first, InstructionsProcessed) |
1221 vectorizeLoadChain(Chains.second, InstructionsProcessed);
1222 }
1223
1224 Align NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(),
1225 Align(StackAdjustedAlignment),
1226 DL, L0, nullptr, &DT);
1227 if (NewAlign >= Alignment)
1228 Alignment = NewAlign;
1229 else
1230 return false;
1231 }
1232
1233 if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) {
1234 auto Chains = splitOddVectorElts(Chain, Sz);
1235 return vectorizeLoadChain(Chains.first, InstructionsProcessed) |
1236 vectorizeLoadChain(Chains.second, InstructionsProcessed);
1237 }
1238
1239 LLVM_DEBUG({
1240 dbgs() << "LSV: Loads to vectorize:\n";
1241 for (Instruction *I : Chain)
1242 I->dump();
1243 });
1244
1245 // getVectorizablePrefix already computed getBoundaryInstrs. The value of
1246 // Last may have changed since then, but the value of First won't have. If it
1247 // matters, we could compute getBoundaryInstrs only once and reuse it here.
1248 BasicBlock::iterator First, Last;
1249 std::tie(First, Last) = getBoundaryInstrs(Chain);
1250 Builder.SetInsertPoint(&*First);
1251
1252 Value *Bitcast =
1253 Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS));
1254 LoadInst *LI =
1255 Builder.CreateAlignedLoad(VecTy, Bitcast, MaybeAlign(Alignment));
1256 propagateMetadata(LI, Chain);
1257
1258 if (VecLoadTy) {
1259 SmallVector<Instruction *, 16> InstrsToErase;
1260
1261 unsigned VecWidth = VecLoadTy->getNumElements();
1262 for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
1263 for (auto Use : Chain[I]->users()) {
1264 // All users of vector loads are ExtractElement instructions with
1265 // constant indices, otherwise we would have bailed before now.
1266 Instruction *UI = cast<Instruction>(Use);
1267 unsigned Idx = cast<ConstantInt>(UI->getOperand(1))->getZExtValue();
1268 unsigned NewIdx = Idx + I * VecWidth;
1269 Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(NewIdx),
1270 UI->getName());
1271 if (V->getType() != UI->getType())
1272 V = Builder.CreateBitCast(V, UI->getType());
1273
1274 // Replace the old instruction.
1275 UI->replaceAllUsesWith(V);
1276 InstrsToErase.push_back(UI);
1277 }
1278 }
1279
1280 // Bitcast might not be an Instruction if the pointer operand is a constant
1281 // expression. In that case, no need to reorder anything.
1282 if (Instruction *BitcastInst = dyn_cast<Instruction>(Bitcast))
1283 reorder(BitcastInst);
1284
1285 for (auto I : InstrsToErase)
1286 I->eraseFromParent();
1287 } else {
1288 for (unsigned I = 0, E = Chain.size(); I != E; ++I) {
1289 Value *CV = Chain[I];
1290 Value *V =
1291 Builder.CreateExtractElement(LI, Builder.getInt32(I), CV->getName());
1292 if (V->getType() != CV->getType()) {
1293 V = Builder.CreateBitOrPointerCast(V, CV->getType());
1294 }
1295
1296 // Replace the old instruction.
1297 CV->replaceAllUsesWith(V);
1298 }
1299
1300 if (Instruction *BitcastInst = dyn_cast<Instruction>(Bitcast))
1301 reorder(BitcastInst);
1302 }
1303
1304 eraseInstructions(Chain);
1305
1306 ++NumVectorInstructions;
1307 NumScalarsVectorized += Chain.size();
1308 return true;
1309 }
1310
1311 bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace,
1312 Align Alignment) {
1313 if (Alignment.value() % SzInBytes == 0)
1314 return false;
1315
1316 bool Fast = false;
1317 bool Allows = TTI.allowsMisalignedMemoryAccesses(F.getParent()->getContext(),
1318 SzInBytes * 8, AddressSpace,
1319 Alignment, &Fast);
1320 LLVM_DEBUG(dbgs() << "LSV: Target said misaligned is allowed? " << Allows
1321 << " and fast? " << Fast << "\n";);
1322 return !Allows || !Fast;
1323 }
1324