1 //===- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass merges loads/stores to/from sequential memory addresses into vector
10 // loads/stores.  Although there's nothing GPU-specific in here, this pass is
11 // motivated by the microarchitectural quirks of NVIDIA and AMD GPUs.
12 //
13 // (For simplicity below we talk about loads only, but everything also applies
14 // to stores.)
15 //
16 // This pass is intended to be run late in the pipeline, after other
17 // vectorization opportunities have been exploited.  So the assumption here is
18 // that immediately following our new vector load we'll need to extract out the
19 // individual elements of the load, so we can operate on them individually.
20 //
21 // On CPUs this transformation is usually not beneficial, because extracting the
22 // elements of a vector register is expensive on most architectures.  It's
23 // usually better just to load each element individually into its own scalar
24 // register.
25 //
26 // However, NVIDIA and AMD GPUs don't have proper vector registers.  Instead, a
27 // "vector load" loads directly into a series of scalar registers.  In effect,
28 // extracting the elements of the vector is free.  It's therefore always
29 // beneficial to vectorize a sequence of loads on these architectures.
30 //
31 // Vectorizing (perhaps a better name might be "coalescing") loads can have
32 // large performance impacts on GPU kernels, and opportunities for vectorizing
33 // are common in GPU code.  This pass tries very hard to find such
34 // opportunities; its runtime is quadratic in the number of loads in a BB.
35 //
36 // Some CPU architectures, such as ARM, have instructions that load into
37 // multiple scalar registers, similar to a GPU vectorized load.  In theory ARM
38 // could use this pass (with some modifications), but currently it implements
39 // its own pass to do something similar to what we do here.
40 //
41 // Overview of the algorithm and terminology in this pass:
42 //
43 //  - Break up each basic block into pseudo-BBs, composed of instructions which
44 //    are guaranteed to transfer control to their successors.
45 //  - Within a single pseudo-BB, find all loads, and group them into
46 //    "equivalence classes" according to getUnderlyingObject() and loaded
47 //    element size.  Do the same for stores.
48 //  - For each equivalence class, greedily build "chains".  Each chain has a
49 //    leader instruction, and every other member of the chain has a known
50 //    constant offset from the first instr in the chain.
51 //  - Break up chains so that they contain only contiguous accesses of legal
52 //    size with no intervening may-alias instrs.
53 //  - Convert each chain to vector instructions.
54 //
55 // The O(n^2) behavior of this pass comes from initially building the chains.
56 // In the worst case we have to compare each new instruction to all of those
57 // that came before. To limit this, we only calculate the offset to the leaders
58 // of the N most recently-used chains.
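//
// As a hedged illustration (the exact IR depends on the target and the types
// involved), a chain of two contiguous i32 loads such as
//
//   %p1 = getelementptr i32, ptr %p, i64 1
//   %a  = load i32, ptr %p
//   %b  = load i32, ptr %p1
//
// is rewritten into roughly
//
//   %v  = load <2 x i32>, ptr %p
//   %a  = extractelement <2 x i32> %v, i32 0
//   %b  = extractelement <2 x i32> %v, i32 1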
59 
60 #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/ArrayRef.h"
63 #include "llvm/ADT/DenseMap.h"
64 #include "llvm/ADT/MapVector.h"
65 #include "llvm/ADT/PostOrderIterator.h"
66 #include "llvm/ADT/STLExtras.h"
67 #include "llvm/ADT/Sequence.h"
68 #include "llvm/ADT/SmallPtrSet.h"
69 #include "llvm/ADT/SmallVector.h"
70 #include "llvm/ADT/Statistic.h"
71 #include "llvm/ADT/iterator_range.h"
72 #include "llvm/Analysis/AliasAnalysis.h"
73 #include "llvm/Analysis/AssumptionCache.h"
74 #include "llvm/Analysis/MemoryLocation.h"
75 #include "llvm/Analysis/ScalarEvolution.h"
76 #include "llvm/Analysis/TargetTransformInfo.h"
77 #include "llvm/Analysis/ValueTracking.h"
78 #include "llvm/Analysis/VectorUtils.h"
79 #include "llvm/IR/Attributes.h"
80 #include "llvm/IR/BasicBlock.h"
81 #include "llvm/IR/ConstantRange.h"
82 #include "llvm/IR/Constants.h"
83 #include "llvm/IR/DataLayout.h"
84 #include "llvm/IR/DerivedTypes.h"
85 #include "llvm/IR/Dominators.h"
86 #include "llvm/IR/Function.h"
87 #include "llvm/IR/GetElementPtrTypeIterator.h"
88 #include "llvm/IR/IRBuilder.h"
89 #include "llvm/IR/InstrTypes.h"
90 #include "llvm/IR/Instruction.h"
91 #include "llvm/IR/Instructions.h"
92 #include "llvm/IR/LLVMContext.h"
93 #include "llvm/IR/Module.h"
94 #include "llvm/IR/Type.h"
95 #include "llvm/IR/Value.h"
96 #include "llvm/InitializePasses.h"
97 #include "llvm/Pass.h"
98 #include "llvm/Support/Alignment.h"
99 #include "llvm/Support/Casting.h"
100 #include "llvm/Support/Debug.h"
101 #include "llvm/Support/KnownBits.h"
102 #include "llvm/Support/MathExtras.h"
103 #include "llvm/Support/ModRef.h"
104 #include "llvm/Support/raw_ostream.h"
105 #include "llvm/Transforms/Utils/Local.h"
106 #include "llvm/Transforms/Vectorize.h"
107 #include <algorithm>
108 #include <cassert>
109 #include <cstdint>
110 #include <cstdlib>
111 #include <iterator>
112 #include <limits>
113 #include <numeric>
114 #include <optional>
115 #include <tuple>
116 #include <type_traits>
117 #include <utility>
118 #include <vector>
119 
120 using namespace llvm;
121 
122 #define DEBUG_TYPE "load-store-vectorizer"
123 
124 STATISTIC(NumVectorInstructions, "Number of vector accesses generated");
125 STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized");
126 
127 namespace {
128 
129 // Equivalence class key, the initial tuple by which we group loads/stores.
130 // Loads/stores with different EqClassKeys are never merged.
131 //
132 // (We could in theory remove element-size from this tuple.  We'd just need
133 // to fix up the vector packing/unpacking code.)
134 using EqClassKey =
135     std::tuple<const Value * /* result of getUnderlyingObject() */,
136                unsigned /* AddrSpace */,
137                unsigned /* Load/Store element size bits */,
138                char /* IsLoad; char b/c bool can't be a DenseMap key */
139                >;
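// For example (illustrative): an i8 load and a <4 x i8> load from the same
// underlying object and address space share a key (element size 8 bits),
// whereas an i32 load from that object lands in a different class.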
140 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const EqClassKey &K) {
141   const auto &[UnderlyingObject, AddrSpace, ElementSize, IsLoad] = K;
142   return OS << (IsLoad ? "load" : "store") << " of " << *UnderlyingObject
143             << " of element size " << ElementSize << " bits in addrspace "
144             << AddrSpace;
145 }
146 
147 // A Chain is a set of instructions such that:
148 //  - All instructions have the same equivalence class, so in particular all are
149 //    loads, or all are stores.
150 //  - We know the address accessed by the i'th chain elem relative to the
151 //    chain's leader instruction, which is the first instr of the chain in BB
152 //    order.
153 //
154 // Chains have two canonical orderings:
155 //  - BB order, sorted by Instr->comesBefore.
156 //  - Offset order, sorted by OffsetFromLeader.
157 // This pass switches back and forth between these orders.
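//
// For example (illustrative): for two i32 loads from %p and %p+4, if the %p
// load appears first in the BB it is the leader and the offsets are 0 and 4;
// if the %p+4 load appears first, it is the leader and the offsets are 0 and
// -4, so BB order and offset order differ.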
158 struct ChainElem {
159   Instruction *Inst;
160   APInt OffsetFromLeader;
161 };
162 using Chain = SmallVector<ChainElem, 1>;
163 
164 void sortChainInBBOrder(Chain &C) {
165   sort(C, [](auto &A, auto &B) { return A.Inst->comesBefore(B.Inst); });
166 }
167 
168 void sortChainInOffsetOrder(Chain &C) {
169   sort(C, [](const auto &A, const auto &B) {
170     if (A.OffsetFromLeader != B.OffsetFromLeader)
171       return A.OffsetFromLeader.slt(B.OffsetFromLeader);
172     return A.Inst->comesBefore(B.Inst); // stable tiebreaker
173   });
174 }
175 
176 void dumpChain(ArrayRef<ChainElem> C) {
177   for (const auto &E : C) {
178     dbgs() << "  " << *E.Inst << " (offset " << E.OffsetFromLeader << ")\n";
179   }
180 }
181 
182 using EquivalenceClassMap =
183     MapVector<EqClassKey, SmallVector<Instruction *, 8>>;
184 
185 // FIXME: Assuming stack alignment of 4 is always good enough
186 constexpr unsigned StackAdjustedAlignment = 4;
187 
188 Instruction *propagateMetadata(Instruction *I, const Chain &C) {
189   SmallVector<Value *, 8> Values;
190   for (const ChainElem &E : C)
191     Values.push_back(E.Inst);
192   return propagateMetadata(I, Values);
193 }
194 
195 bool isInvariantLoad(const Instruction *I) {
196   const LoadInst *LI = dyn_cast<LoadInst>(I);
197   return LI != nullptr && LI->hasMetadata(LLVMContext::MD_invariant_load);
198 }
199 
200 /// Reorders the instructions that I depends on (the instructions defining its
201 /// operands), to ensure they dominate I.
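/// A hedged illustration (names invented): if %gep = getelementptr ... is
/// currently defined after a newly created vector load %v that uses it,
/// reorder(%v) moves %gep (and any of its same-BB operand instructions) to
/// just before %v.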
202 void reorder(Instruction *I) {
203   SmallPtrSet<Instruction *, 16> InstructionsToMove;
204   SmallVector<Instruction *, 16> Worklist;
205 
206   Worklist.push_back(I);
207   while (!Worklist.empty()) {
208     Instruction *IW = Worklist.pop_back_val();
209     int NumOperands = IW->getNumOperands();
210     for (int i = 0; i < NumOperands; i++) {
211       Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i));
212       if (!IM || IM->getOpcode() == Instruction::PHI)
213         continue;
214 
215       // If IM is in another BB, no need to move it, because this pass only
216       // vectorizes instructions within one BB.
217       if (IM->getParent() != I->getParent())
218         continue;
219 
220       if (!IM->comesBefore(I)) {
221         InstructionsToMove.insert(IM);
222         Worklist.push_back(IM);
223       }
224     }
225   }
226 
227   // All instructions to move should follow I. Start from I, not from begin().
228   for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;) {
229     Instruction *IM = &*(BBI++);
230     if (!InstructionsToMove.count(IM))
231       continue;
232     IM->moveBefore(I);
233   }
234 }
235 
236 class Vectorizer {
237   Function &F;
238   AliasAnalysis &AA;
239   AssumptionCache &AC;
240   DominatorTree &DT;
241   ScalarEvolution &SE;
242   TargetTransformInfo &TTI;
243   const DataLayout &DL;
244   IRBuilder<> Builder;
245 
246   // We could erase instrs right after vectorizing them, but that can mess up
247   // our BB iterators, and also can make the equivalence class keys point to
248   // freed memory.  This is fixable, but it's simpler just to wait until we're
249   // done with the BB and erase all at once.
250   SmallVector<Instruction *, 128> ToErase;
251 
252 public:
253   Vectorizer(Function &F, AliasAnalysis &AA, AssumptionCache &AC,
254              DominatorTree &DT, ScalarEvolution &SE, TargetTransformInfo &TTI)
255       : F(F), AA(AA), AC(AC), DT(DT), SE(SE), TTI(TTI),
256         DL(F.getParent()->getDataLayout()), Builder(SE.getContext()) {}
257 
258   bool run();
259 
260 private:
261   static const unsigned MaxDepth = 3;
262 
263   /// Runs the vectorizer on a "pseudo basic block", which is a range of
264   /// instructions [Begin, End) within one BB all of which have
265   /// isGuaranteedToTransferExecutionToSuccessor(I) == true.
266   bool runOnPseudoBB(BasicBlock::iterator Begin, BasicBlock::iterator End);
267 
268   /// Runs the vectorizer on one equivalence class, i.e. one set of loads/stores
269   /// in the same BB with the same value for getUnderlyingObject() etc.
270   bool runOnEquivalenceClass(const EqClassKey &EqClassKey,
271                              ArrayRef<Instruction *> EqClass);
272 
273   /// Runs the vectorizer on one chain, i.e. a subset of an equivalence class
274   /// where all instructions access a known, constant offset from the first
275   /// instruction.
276   bool runOnChain(Chain &C);
277 
278   /// Splits the chain into subchains of instructions which read/write a
279   /// contiguous block of memory.  Discards any length-1 subchains (because
280   /// there's nothing to vectorize in there).
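  /// For example (illustrative): accesses covering byte ranges [0,4) and [4,8)
  /// are contiguous and stay in one subchain, while [0,4) followed by [8,12)
  /// is split between the two accesses.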
281   std::vector<Chain> splitChainByContiguity(Chain &C);
282 
283   /// Splits the chain into subchains where it's safe to hoist loads up to the
284   /// beginning of the sub-chain and it's safe to sink loads up to the end of
285   /// the sub-chain.  Discards any length-1 subchains.
286   std::vector<Chain> splitChainByMayAliasInstrs(Chain &C);
287 
288   /// Splits the chain into subchains that make legal, aligned accesses.
289   /// Discards any length-1 subchains.
290   std::vector<Chain> splitChainByAlignment(Chain &C);
291 
292   /// Converts the instrs in the chain into a single vectorized load or store.
293   /// Adds the old scalar loads/stores to ToErase.
294   bool vectorizeChain(Chain &C);
295 
296   /// Tries to compute the offset in bytes PtrB - PtrA.
297   std::optional<APInt> getConstantOffset(Value *PtrA, Value *PtrB,
298                                          unsigned Depth = 0);
299   std::optional<APInt> getConstantOffsetComplexAddrs(Value *PtrA, Value *PtrB,
300                                                      unsigned Depth);
301   std::optional<APInt> getConstantOffsetSelects(Value *PtrA, Value *PtrB,
302                                                 unsigned Depth);
303 
304   /// Gets the element type of the vector that the chain will load or store.
305   /// This is nontrivial because the chain may contain elements of different
306   /// types; e.g. it's legal to have a chain that contains both i32 and float.
307   Type *getChainElemTy(const Chain &C);
308 
309   /// Determines whether ChainElem can be moved up (if IsLoad) or down (if
310   /// !IsLoad) to ChainBegin -- i.e. there are no intervening may-alias
311   /// instructions.
312   ///
313   /// The map ChainOffsets must contain all of the elements in
314   /// [ChainBegin, ChainElem] and their offsets from some arbitrary base
315   /// address.  It's ok if it contains additional entries.
316   template <bool IsLoadChain>
317   bool isSafeToMove(
318       Instruction *ChainElem, Instruction *ChainBegin,
319       const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);
320 
321   /// Collects loads and stores grouped by "equivalence class", where:
322   ///   - all elements in an eq class are a load or all are a store,
323   ///   - they all load/store the same element size (it's OK to have e.g. i8 and
324   ///     <4 x i8> in the same class, but not i32 and <4 x i8>), and
325   ///   - they all have the same value for getUnderlyingObject().
326   EquivalenceClassMap collectEquivalenceClasses(BasicBlock::iterator Begin,
327                                                 BasicBlock::iterator End);
328 
329   /// Partitions Instrs into "chains" where every instruction has a known
330   /// constant offset from the first instr in the chain.
331   ///
332   /// Postcondition: For all i, ret[i][0].OffsetFromLeader == 0, because the first instr
333   /// in the chain is the leader, and an instr touches distance 0 from itself.
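  /// For example (illustrative): given loads from %p, %p+4, and an unrelated
  /// pointer %q, the %p-based loads form one chain with offsets 0 and 4, while
  /// the %q load ends up in its own (trivial) chain.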
334   std::vector<Chain> gatherChains(ArrayRef<Instruction *> Instrs);
335 };
336 
337 class LoadStoreVectorizerLegacyPass : public FunctionPass {
338 public:
339   static char ID;
340 
341   LoadStoreVectorizerLegacyPass() : FunctionPass(ID) {
342     initializeLoadStoreVectorizerLegacyPassPass(
343         *PassRegistry::getPassRegistry());
344   }
345 
346   bool runOnFunction(Function &F) override;
347 
348   StringRef getPassName() const override {
349     return "GPU Load and Store Vectorizer";
350   }
351 
352   void getAnalysisUsage(AnalysisUsage &AU) const override {
353     AU.addRequired<AAResultsWrapperPass>();
354     AU.addRequired<AssumptionCacheTracker>();
355     AU.addRequired<ScalarEvolutionWrapperPass>();
356     AU.addRequired<DominatorTreeWrapperPass>();
357     AU.addRequired<TargetTransformInfoWrapperPass>();
358     AU.setPreservesCFG();
359   }
360 };
361 
362 } // end anonymous namespace
363 
364 char LoadStoreVectorizerLegacyPass::ID = 0;
365 
366 INITIALIZE_PASS_BEGIN(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
367                       "Vectorize load and store instructions", false, false)
368 INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
369 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker);
370 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
371 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
372 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
373 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
374 INITIALIZE_PASS_END(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
375                     "Vectorize load and store instructions", false, false)
376 
377 Pass *llvm::createLoadStoreVectorizerPass() {
378   return new LoadStoreVectorizerLegacyPass();
379 }
380 
381 bool LoadStoreVectorizerLegacyPass::runOnFunction(Function &F) {
382   // Don't vectorize when the attribute NoImplicitFloat is used.
383   if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat))
384     return false;
385 
386   AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
387   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
388   ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
389   TargetTransformInfo &TTI =
390       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
391 
392   AssumptionCache &AC =
393       getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
394 
395   return Vectorizer(F, AA, AC, DT, SE, TTI).run();
396 }
397 
398 PreservedAnalyses LoadStoreVectorizerPass::run(Function &F,
399                                                FunctionAnalysisManager &AM) {
400   // Don't vectorize when the attribute NoImplicitFloat is used.
401   if (F.hasFnAttribute(Attribute::NoImplicitFloat))
402     return PreservedAnalyses::all();
403 
404   AliasAnalysis &AA = AM.getResult<AAManager>(F);
405   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
406   ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
407   TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
408   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
409 
410   bool Changed = Vectorizer(F, AA, AC, DT, SE, TTI).run();
411   PreservedAnalyses PA;
412   PA.preserveSet<CFGAnalyses>();
413   return Changed ? PA : PreservedAnalyses::all();
414 }
415 
416 bool Vectorizer::run() {
417   bool Changed = false;
418   // Break up the BB if there are any instrs which aren't guaranteed to transfer
419   // execution to their successor.
420   //
421   // Consider, for example:
422   //
423   //   def assert_arr_len(int n) { if (n < 2) exit(); }
424   //
425   //   load arr[0]
426   //   call assert_arr_len(arr.length)
427   //   load arr[1]
428   //
429   // Even though assert_arr_len does not read or write any memory, we can't
430   // speculate the second load before the call.  More info at
431   // https://github.com/llvm/llvm-project/issues/52950.
432   for (BasicBlock *BB : post_order(&F)) {
433     // BB must at least have a terminator.
434     assert(!BB->empty());
435 
436     SmallVector<BasicBlock::iterator, 8> Barriers;
437     Barriers.push_back(BB->begin());
438     for (Instruction &I : *BB)
439       if (!isGuaranteedToTransferExecutionToSuccessor(&I))
440         Barriers.push_back(I.getIterator());
441     Barriers.push_back(BB->end());
442 
443     for (auto It = Barriers.begin(), End = std::prev(Barriers.end()); It != End;
444          ++It)
445       Changed |= runOnPseudoBB(*It, *std::next(It));
446 
447     for (Instruction *I : ToErase) {
448       auto *PtrOperand = getLoadStorePointerOperand(I);
449       if (I->use_empty())
450         I->eraseFromParent();
451       RecursivelyDeleteTriviallyDeadInstructions(PtrOperand);
452     }
453     ToErase.clear();
454   }
455 
456   return Changed;
457 }
458 
459 bool Vectorizer::runOnPseudoBB(BasicBlock::iterator Begin,
460                                BasicBlock::iterator End) {
461   LLVM_DEBUG({
462     dbgs() << "LSV: Running on pseudo-BB [" << *Begin << " ... ";
463     if (End != Begin->getParent()->end())
464       dbgs() << *End;
465     else
466       dbgs() << "<BB end>";
467     dbgs() << ")\n";
468   });
469 
470   bool Changed = false;
471   for (const auto &[EqClassKey, EqClass] :
472        collectEquivalenceClasses(Begin, End))
473     Changed |= runOnEquivalenceClass(EqClassKey, EqClass);
474 
475   return Changed;
476 }
477 
478 bool Vectorizer::runOnEquivalenceClass(const EqClassKey &EqClassKey,
479                                        ArrayRef<Instruction *> EqClass) {
480   bool Changed = false;
481 
482   LLVM_DEBUG({
483     dbgs() << "LSV: Running on equivalence class of size " << EqClass.size()
484            << " keyed on " << EqClassKey << ":\n";
485     for (Instruction *I : EqClass)
486       dbgs() << "  " << *I << "\n";
487   });
488 
489   std::vector<Chain> Chains = gatherChains(EqClass);
490   LLVM_DEBUG(dbgs() << "LSV: Got " << Chains.size()
491                     << " nontrivial chains.\n";);
492   for (Chain &C : Chains)
493     Changed |= runOnChain(C);
494   return Changed;
495 }
496 
497 bool Vectorizer::runOnChain(Chain &C) {
498   LLVM_DEBUG({
499     dbgs() << "LSV: Running on chain with " << C.size() << " instructions:\n";
500     dumpChain(C);
501   });
502 
503   // Split up the chain into increasingly smaller chains, until we can finally
504   // vectorize the chains.
505   //
506   // (Don't be scared by the depth of the loop nest here.  These operations are
507   // all at worst O(n lg n) in the number of instructions, and splitting chains
508   // doesn't change the number of instrs.  So the whole loop nest is O(n lg n).)
509   bool Changed = false;
510   for (auto &C : splitChainByMayAliasInstrs(C))
511     for (auto &C : splitChainByContiguity(C))
512       for (auto &C : splitChainByAlignment(C))
513         Changed |= vectorizeChain(C);
514   return Changed;
515 }
516 
517 std::vector<Chain> Vectorizer::splitChainByMayAliasInstrs(Chain &C) {
518   if (C.empty())
519     return {};
520 
521   sortChainInBBOrder(C);
522 
523   LLVM_DEBUG({
524     dbgs() << "LSV: splitChainByMayAliasInstrs considering chain:\n";
525     dumpChain(C);
526   });
527 
528   // We know that elements in the chain with non-overlapping offsets can't
529   // alias, but AA may not be smart enough to figure this out.  Use a
530   // hashtable.
531   DenseMap<Instruction *, APInt /*OffsetFromLeader*/> ChainOffsets;
532   for (const auto &E : C)
533     ChainOffsets.insert({&*E.Inst, E.OffsetFromLeader});
534 
535   // Loads get hoisted up to the first load in the chain.  Stores get sunk
536   // down to the last store in the chain.  Our algorithm for loads is:
537   //
538   //  - Take the first element of the chain.  This is the start of a new chain.
539   //  - Take the next element of `Chain` and check for may-alias instructions
540   //    up to the start of NewChain.  If no may-alias instrs, add it to
541   //    NewChain.  Otherwise, start a new NewChain.
542   //
543   // For stores it's the same except in the reverse direction.
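  //
  // A hedged illustration (IR names invented): given the load chain
  //   %a = load i32, ptr %p          ; chain offset 0
  //   store i32 0, ptr %q            ; may alias %p+4 as far as AA knows
  //   %b = load i32, ptr %p4         ; chain offset 4
  // the store blocks hoisting %b up to %a, so the chain is split into {%a} and
  // {%b}, both of which are then discarded as length-1 chains.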
544   //
545   // We expect IsLoad to be an std::bool_constant.
546   auto Impl = [&](auto IsLoad) {
547     // MSVC is unhappy if IsLoad is a capture, so pass it as an arg.
548     auto [ChainBegin, ChainEnd] = [&](auto IsLoad) {
549       if constexpr (IsLoad())
550         return std::make_pair(C.begin(), C.end());
551       else
552         return std::make_pair(C.rbegin(), C.rend());
553     }(IsLoad);
554     assert(ChainBegin != ChainEnd);
555 
556     std::vector<Chain> Chains;
557     SmallVector<ChainElem, 1> NewChain;
558     NewChain.push_back(*ChainBegin);
559     for (auto ChainIt = std::next(ChainBegin); ChainIt != ChainEnd; ++ChainIt) {
560       if (isSafeToMove<IsLoad>(ChainIt->Inst, NewChain.front().Inst,
561                                ChainOffsets)) {
562         LLVM_DEBUG(dbgs() << "LSV: No intervening may-alias instrs; can merge "
563                           << *ChainIt->Inst << " into " << *ChainBegin->Inst
564                           << "\n");
565         NewChain.push_back(*ChainIt);
566       } else {
567         LLVM_DEBUG(
568             dbgs() << "LSV: Found intervening may-alias instrs; cannot merge "
569                    << *ChainIt->Inst << " into " << *ChainBegin->Inst << "\n");
570         if (NewChain.size() > 1) {
571           LLVM_DEBUG({
572             dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n";
573             dumpChain(NewChain);
574           });
575           Chains.push_back(std::move(NewChain));
576         }
577 
578         // Start a new chain.
579         NewChain = SmallVector<ChainElem, 1>({*ChainIt});
580       }
581     }
582     if (NewChain.size() > 1) {
583       LLVM_DEBUG({
584         dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n";
585         dumpChain(NewChain);
586       });
587       Chains.push_back(std::move(NewChain));
588     }
589     return Chains;
590   };
591 
592   if (isa<LoadInst>(C[0].Inst))
593     return Impl(/*IsLoad=*/std::bool_constant<true>());
594 
595   assert(isa<StoreInst>(C[0].Inst));
596   return Impl(/*IsLoad=*/std::bool_constant<false>());
597 }
598 
599 std::vector<Chain> Vectorizer::splitChainByContiguity(Chain &C) {
600   if (C.empty())
601     return {};
602 
603   sortChainInOffsetOrder(C);
604 
605   LLVM_DEBUG({
606     dbgs() << "LSV: splitChainByContiguity considering chain:\n";
607     dumpChain(C);
608   });
609 
610   std::vector<Chain> Ret;
611   Ret.push_back({C.front()});
612 
613   for (auto It = std::next(C.begin()), End = C.end(); It != End; ++It) {
614     // `Prev` accesses the byte range [Prev.OffsetFromLeader, PrevReadEnd).
615     auto &CurChain = Ret.back();
616     const ChainElem &Prev = CurChain.back();
617     unsigned SzBits = DL.getTypeSizeInBits(getLoadStoreType(&*Prev.Inst));
618     assert(SzBits % 8 == 0 && "Non-byte sizes should have been filtered out by "
619                               "collectEquivalenceClasses");
620     APInt PrevReadEnd = Prev.OffsetFromLeader + SzBits / 8;
621 
622     // Add this instruction to the end of the current chain, or start a new one.
623     bool AreContiguous = It->OffsetFromLeader == PrevReadEnd;
624     LLVM_DEBUG(dbgs() << "LSV: Instructions are "
625                       << (AreContiguous ? "" : "not ") << "contiguous: "
626                       << *Prev.Inst << " (ends at offset " << PrevReadEnd
627                       << ") -> " << *It->Inst << " (starts at offset "
628                       << It->OffsetFromLeader << ")\n");
629     if (AreContiguous)
630       CurChain.push_back(*It);
631     else
632       Ret.push_back({*It});
633   }
634 
635   // Filter out length-1 chains, these are uninteresting.
636   llvm::erase_if(Ret, [](const auto &Chain) { return Chain.size() <= 1; });
637   return Ret;
638 }
639 
640 Type *Vectorizer::getChainElemTy(const Chain &C) {
641   assert(!C.empty());
642   // The rules are:
643   //  - If there are any pointer types in the chain, use an integer type.
644   //  - Prefer an integer type if it appears in the chain.
645   //  - Otherwise, use the first type in the chain.
646   //
647   // The rule about pointer types is a simplification when we merge e.g.  a load
648   // of a ptr and a double.  There's no direct conversion from a ptr to a
649   // double; it requires a ptrtoint followed by a bitcast.
650   //
651   // It's unclear to me if the other rules have any practical effect, but we do
652   // it to match this pass's previous behavior.
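  //
  // For example (illustrative): a chain containing a double load and a ptr
  // load (in a 64-bit address space) gets element type i64, while a chain of a
  // float load and an i32 load gets i32.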
653   if (any_of(C, [](const ChainElem &E) {
654         return getLoadStoreType(E.Inst)->getScalarType()->isPointerTy();
655       })) {
656     return Type::getIntNTy(
657         F.getContext(),
658         DL.getTypeSizeInBits(getLoadStoreType(C[0].Inst)->getScalarType()));
659   }
660 
661   for (const ChainElem &E : C)
662     if (Type *T = getLoadStoreType(E.Inst)->getScalarType(); T->isIntegerTy())
663       return T;
664   return getLoadStoreType(C[0].Inst)->getScalarType();
665 }
666 
667 std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
668   // We use a simple greedy algorithm.
669   //  - Given a chain of length N, find all prefixes that
670   //    (a) are not longer than the max register length, and
671   //    (b) are a power of 2.
672   //  - Starting from the longest prefix, try to create a vector of that length.
673   //  - If one of them works, great.  Repeat the algorithm on any remaining
674   //    elements in the chain.
675   //  - If none of them work, discard the first element and repeat on a chain
676   //    of length N-1.
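  //
  // A hedged illustration: with a 16-byte vector register and a chain of four
  // contiguous i32 accesses, we first try the full 16-byte candidate; if the
  // alignment or TTI checks reject it, we try shorter candidates, and if none
  // works we drop the first access and repeat on the remaining three.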
677   if (C.empty())
678     return {};
679 
680   sortChainInOffsetOrder(C);
681 
682   LLVM_DEBUG({
683     dbgs() << "LSV: splitChainByAlignment considering chain:\n";
684     dumpChain(C);
685   });
686 
687   bool IsLoadChain = isa<LoadInst>(C[0].Inst);
688   auto getVectorFactor = [&](unsigned VF, unsigned LoadStoreSize,
689                              unsigned ChainSizeBytes, VectorType *VecTy) {
690     return IsLoadChain ? TTI.getLoadVectorFactor(VF, LoadStoreSize,
691                                                  ChainSizeBytes, VecTy)
692                        : TTI.getStoreVectorFactor(VF, LoadStoreSize,
693                                                   ChainSizeBytes, VecTy);
694   };
695 
696 #ifndef NDEBUG
697   for (const auto &E : C) {
698     Type *Ty = getLoadStoreType(E.Inst)->getScalarType();
699     assert(isPowerOf2_32(DL.getTypeSizeInBits(Ty)) &&
700            "Should have filtered out non-power-of-two elements in "
701            "collectEquivalenceClasses.");
702   }
703 #endif
704 
705   unsigned AS = getLoadStoreAddressSpace(C[0].Inst);
706   unsigned VecRegBytes = TTI.getLoadStoreVecRegBitWidth(AS) / 8;
707 
708   std::vector<Chain> Ret;
709   for (unsigned CBegin = 0; CBegin < C.size(); ++CBegin) {
710     // Find candidate chains of size not greater than the largest vector reg.
711     // These chains are over the closed interval [CBegin, CEnd].
712     SmallVector<std::pair<unsigned /*CEnd*/, unsigned /*SizeBytes*/>, 8>
713         CandidateChains;
714     for (unsigned CEnd = CBegin + 1, Size = C.size(); CEnd < Size; ++CEnd) {
715       APInt Sz = C[CEnd].OffsetFromLeader +
716                  DL.getTypeStoreSize(getLoadStoreType(C[CEnd].Inst)) -
717                  C[CBegin].OffsetFromLeader;
718       if (Sz.sgt(VecRegBytes))
719         break;
720       CandidateChains.push_back(
721           {CEnd, static_cast<unsigned>(Sz.getLimitedValue())});
722     }
723 
724     // Consider the longest chain first.
725     for (auto It = CandidateChains.rbegin(), End = CandidateChains.rend();
726          It != End; ++It) {
727       auto [CEnd, SizeBytes] = *It;
728       LLVM_DEBUG(
729           dbgs() << "LSV: splitChainByAlignment considering candidate chain ["
730                  << *C[CBegin].Inst << " ... " << *C[CEnd].Inst << "]\n");
731 
732       Type *VecElemTy = getChainElemTy(C);
733       // Note, VecElemTy is a power of 2, but might be less than one byte.  For
734       // example, we can vectorize 2 x <2 x i4> to <4 x i4>, and in this case
735       // VecElemTy would be i4.
736       unsigned VecElemBits = DL.getTypeSizeInBits(VecElemTy);
737 
738       // SizeBytes and VecElemBits are powers of 2, so they divide evenly.
739       assert((8 * SizeBytes) % VecElemBits == 0);
740       unsigned NumVecElems = 8 * SizeBytes / VecElemBits;
741       FixedVectorType *VecTy = FixedVectorType::get(VecElemTy, NumVecElems);
742       unsigned VF = 8 * VecRegBytes / VecElemBits;
743 
744       // Check that TTI is happy with this vectorization factor.
745       unsigned TargetVF = getVectorFactor(VF, VecElemBits,
746                                           VecElemBits * NumVecElems / 8, VecTy);
747       if (TargetVF != VF && TargetVF < NumVecElems) {
748         LLVM_DEBUG(
749             dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
750                       "because TargetVF="
751                    << TargetVF << " != VF=" << VF
752                    << " and TargetVF < NumVecElems=" << NumVecElems << "\n");
753         continue;
754       }
755 
756       // Is a load/store with this alignment allowed by TTI and at least as fast
757       // as an unvectorized load/store?
758       //
759       // TTI and F are passed as explicit captures to work around an MSVC misparse (??).
760       auto IsAllowedAndFast = [&, SizeBytes = SizeBytes, &TTI = TTI,
761                                &F = F](Align Alignment) {
762         if (Alignment.value() % SizeBytes == 0)
763           return true;
764         unsigned VectorizedSpeed = 0;
765         bool AllowsMisaligned = TTI.allowsMisalignedMemoryAccesses(
766             F.getContext(), SizeBytes * 8, AS, Alignment, &VectorizedSpeed);
767         if (!AllowsMisaligned) {
768           LLVM_DEBUG(dbgs()
769                      << "LSV: Access of " << SizeBytes << "B in addrspace "
770                      << AS << " with alignment " << Alignment.value()
771                      << " is misaligned, and therefore can't be vectorized.\n");
772           return false;
773         }
774 
775         unsigned ElementwiseSpeed = 0;
776         (TTI).allowsMisalignedMemoryAccesses((F).getContext(), VecElemBits, AS,
777                                              Alignment, &ElementwiseSpeed);
778         if (VectorizedSpeed < ElementwiseSpeed) {
779           LLVM_DEBUG(dbgs()
780                      << "LSV: Access of " << SizeBytes << "B in addrspace "
781                      << AS << " with alignment " << Alignment.value()
782                      << " has relative speed " << VectorizedSpeed
783                      << ", which is lower than the elementwise speed of "
784                      << ElementwiseSpeed
785                      << ".  Therefore this access won't be vectorized.\n");
786           return false;
787         }
788         return true;
789       };
790 
791       // If we're loading/storing from an alloca, align it if possible.
792       //
793       // FIXME: We eagerly upgrade the alignment, regardless of whether TTI
794       // tells us this is beneficial.  This feels a bit odd, but it matches
795       // existing tests.  This isn't *so* bad, because at most we align to 4
796       // bytes (current value of StackAdjustedAlignment).
797       //
798       // FIXME: We will upgrade the alignment of the alloca even if it turns out
799       // we can't vectorize for some other reason.
800       Align Alignment = getLoadStoreAlignment(C[CBegin].Inst);
801       if (AS == DL.getAllocaAddrSpace() && Alignment.value() % SizeBytes != 0 &&
802           IsAllowedAndFast(Align(StackAdjustedAlignment))) {
803         Align NewAlign = getOrEnforceKnownAlignment(
804             getLoadStorePointerOperand(C[CBegin].Inst),
805             Align(StackAdjustedAlignment), DL, C[CBegin].Inst, nullptr, &DT);
806         if (NewAlign >= Alignment) {
807           LLVM_DEBUG(dbgs()
808                      << "LSV: splitChainByAlignment upgrading alloca alignment from "
809                      << Alignment.value() << " to " << NewAlign.value()
810                      << "\n");
811           Alignment = NewAlign;
812         }
813       }
814 
815       if (!IsAllowedAndFast(Alignment)) {
816         LLVM_DEBUG(
817             dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
818                       "because its alignment is not AllowedAndFast: "
819                    << Alignment.value() << "\n");
820         continue;
821       }
822 
823       if ((IsLoadChain &&
824            !TTI.isLegalToVectorizeLoadChain(SizeBytes, Alignment, AS)) ||
825           (!IsLoadChain &&
826            !TTI.isLegalToVectorizeStoreChain(SizeBytes, Alignment, AS))) {
827         LLVM_DEBUG(
828             dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
829                       "because !isLegalToVectorizeLoad/StoreChain.");
830         continue;
831       }
832 
833       // Hooray, we can vectorize this chain!
834       Chain &NewChain = Ret.emplace_back();
835       for (unsigned I = CBegin; I <= CEnd; ++I)
836         NewChain.push_back(C[I]);
837       CBegin = CEnd; // Skip over the instructions we've added to the chain.
838       break;
839     }
840   }
841   return Ret;
842 }
843 
844 bool Vectorizer::vectorizeChain(Chain &C) {
845   if (C.size() < 2)
846     return false;
847 
848   sortChainInOffsetOrder(C);
849 
850   LLVM_DEBUG({
851     dbgs() << "LSV: Vectorizing chain of " << C.size() << " instructions:\n";
852     dumpChain(C);
853   });
854 
855   Type *VecElemTy = getChainElemTy(C);
856   bool IsLoadChain = isa<LoadInst>(C[0].Inst);
857   unsigned AS = getLoadStoreAddressSpace(C[0].Inst);
858   unsigned ChainBytes = std::accumulate(
859       C.begin(), C.end(), 0u, [&](unsigned Bytes, const ChainElem &E) {
860         return Bytes + DL.getTypeStoreSize(getLoadStoreType(E.Inst));
861       });
862   assert(ChainBytes % DL.getTypeStoreSize(VecElemTy) == 0);
863   // The size of VecTy is a power of 2 and at least 1 byte, but VecElemTy may be
864   // smaller than 1 byte (e.g. VecTy == <32 x i1>).
865   Type *VecTy = FixedVectorType::get(
866       VecElemTy, 8 * ChainBytes / DL.getTypeSizeInBits(VecElemTy));
867 
868   Align Alignment = getLoadStoreAlignment(C[0].Inst);
869   // If this is a load/store of an alloca, we might have upgraded the alloca's
870   // alignment earlier.  Get the new alignment.
871   if (AS == DL.getAllocaAddrSpace()) {
872     Alignment = std::max(
873         Alignment,
874         getOrEnforceKnownAlignment(getLoadStorePointerOperand(C[0].Inst),
875                                    MaybeAlign(), DL, C[0].Inst, nullptr, &DT));
876   }
877 
878   // All elements of the chain must have the same scalar-type size.
879 #ifndef NDEBUG
880   for (const ChainElem &E : C)
881     assert(DL.getTypeStoreSize(getLoadStoreType(E.Inst)->getScalarType()) ==
882            DL.getTypeStoreSize(VecElemTy));
883 #endif
884 
885   Instruction *VecInst;
886   if (IsLoadChain) {
887     // Loads get hoisted to the location of the first load in the chain.  We may
888     // also need to hoist the (transitive) operands of the loads.
889     Builder.SetInsertPoint(
890         std::min_element(C.begin(), C.end(), [](const auto &A, const auto &B) {
891           return A.Inst->comesBefore(B.Inst);
892         })->Inst);
893 
894     // Chain is in offset order, so C[0] is the instr with the lowest offset,
895     // i.e. the root of the vector.
896     Value *Bitcast = Builder.CreateBitCast(
897         getLoadStorePointerOperand(C[0].Inst), VecTy->getPointerTo(AS));
898     VecInst = Builder.CreateAlignedLoad(VecTy, Bitcast, Alignment);
899 
900     unsigned VecIdx = 0;
901     for (const ChainElem &E : C) {
902       Instruction *I = E.Inst;
903       Value *V;
904       Type *T = getLoadStoreType(I);
905       if (auto *VT = dyn_cast<FixedVectorType>(T)) {
906         auto Mask = llvm::to_vector<8>(
907             llvm::seq<int>(VecIdx, VecIdx + VT->getNumElements()));
908         V = Builder.CreateShuffleVector(VecInst, Mask, I->getName());
909         VecIdx += VT->getNumElements();
910       } else {
911         V = Builder.CreateExtractElement(VecInst, Builder.getInt32(VecIdx),
912                                          I->getName());
913         ++VecIdx;
914       }
915       if (V->getType() != I->getType())
916         V = Builder.CreateBitOrPointerCast(V, I->getType());
917       I->replaceAllUsesWith(V);
918     }
919 
920     // Finally, we need to reorder the instrs in the BB so that the (transitive)
921     // operands of VecInst appear before it.  To see why, suppose we have
922     // vectorized the following code:
923     //
924     //   ptr1  = gep a, 1
925     //   load1 = load i32 ptr1
926     //   ptr0  = gep a, 0
927     //   load0 = load i32 ptr0
928     //
929     // We will put the vectorized load at the location of the earliest load in
930     // the BB, i.e. load1.  We get:
931     //
932     //   ptr1  = gep a, 1
933     //   loadv = load <2 x i32> ptr0
934     //   load0 = extractelement loadv, 0
935     //   load1 = extractelement loadv, 1
936     //   ptr0 = gep a, 0
937     //
938     // Notice that loadv uses ptr0, which is defined *after* it!
939     reorder(VecInst);
940   } else {
941     // Stores get sunk to the location of the last store in the chain.
942     Builder.SetInsertPoint(
943         std::max_element(C.begin(), C.end(), [](auto &A, auto &B) {
944           return A.Inst->comesBefore(B.Inst);
945         })->Inst);
946 
947     // Build the vector to store.
948     Value *Vec = PoisonValue::get(VecTy);
949     unsigned VecIdx = 0;
950     auto InsertElem = [&](Value *V) {
951       if (V->getType() != VecElemTy)
952         V = Builder.CreateBitOrPointerCast(V, VecElemTy);
953       Vec = Builder.CreateInsertElement(Vec, V, Builder.getInt32(VecIdx++));
954     };
955     for (const ChainElem &E : C) {
956       auto I = cast<StoreInst>(E.Inst);
957       if (FixedVectorType *VT =
958               dyn_cast<FixedVectorType>(getLoadStoreType(I))) {
959         for (int J = 0, JE = VT->getNumElements(); J < JE; ++J) {
960           InsertElem(Builder.CreateExtractElement(I->getValueOperand(),
961                                                   Builder.getInt32(J)));
962         }
963       } else {
964         InsertElem(I->getValueOperand());
965       }
966     }
967 
968     // Chain is in offset order, so C[0] is the instr with the lowest offset,
969     // i.e. the root of the vector.
970     VecInst = Builder.CreateAlignedStore(
971         Vec,
972         Builder.CreateBitCast(getLoadStorePointerOperand(C[0].Inst),
973                               VecTy->getPointerTo(AS)),
974         Alignment);
975   }
976 
977   propagateMetadata(VecInst, C);
978 
979   for (const ChainElem &E : C)
980     ToErase.push_back(E.Inst);
981 
982   ++NumVectorInstructions;
983   NumScalarsVectorized += C.size();
984   return true;
985 }
986 
987 template <bool IsLoadChain>
988 bool Vectorizer::isSafeToMove(
989     Instruction *ChainElem, Instruction *ChainBegin,
990     const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets) {
991   LLVM_DEBUG(dbgs() << "LSV: isSafeToMove(" << *ChainElem << " -> "
992                     << *ChainBegin << ")\n");
993 
994   assert(isa<LoadInst>(ChainElem) == IsLoadChain);
995   if (ChainElem == ChainBegin)
996     return true;
997 
998   // Invariant loads can always be reordered; by definition they are not
999   // clobbered by stores.
1000   if (isInvariantLoad(ChainElem))
1001     return true;
1002 
1003   auto BBIt = std::next([&] {
1004     if constexpr (IsLoadChain)
1005       return BasicBlock::reverse_iterator(ChainElem);
1006     else
1007       return BasicBlock::iterator(ChainElem);
1008   }());
1009   auto BBItEnd = std::next([&] {
1010     if constexpr (IsLoadChain)
1011       return BasicBlock::reverse_iterator(ChainBegin);
1012     else
1013       return BasicBlock::iterator(ChainBegin);
1014   }());
1015 
1016   const APInt &ChainElemOffset = ChainOffsets.at(ChainElem);
1017   const unsigned ChainElemSize =
1018       DL.getTypeStoreSize(getLoadStoreType(ChainElem));
1019 
1020   for (; BBIt != BBItEnd; ++BBIt) {
1021     Instruction *I = &*BBIt;
1022 
1023     if (!I->mayReadOrWriteMemory())
1024       continue;
1025 
1026     // Loads can be reordered with other loads.
1027     if (IsLoadChain && isa<LoadInst>(I))
1028       continue;
1029 
1030     // Stores can be sunk below invariant loads.
1031     if (!IsLoadChain && isInvariantLoad(I))
1032       continue;
1033 
1034     // If I is in the chain, we can tell whether it aliases ChainElem by checking
1035     // what offset I accesses.  This may be better than AA is able to do.
1036     //
1037     // We should really only have duplicate offsets for stores (the duplicate
1038     // loads should be CSE'ed), but in case we have a duplicate load, we'll
1039     // split the chain so we don't have to handle this case specially.
1040     if (auto OffsetIt = ChainOffsets.find(I); OffsetIt != ChainOffsets.end()) {
1041       // I and ChainElem overlap if:
1042       //   - I and ChainElem have the same offset, OR
1043       //   - I's offset is less than ChainElem's, but I touches past the
1044       //     beginning of ChainElem, OR
1045       //   - ChainElem's offset is less than I's, but ChainElem touches past the
1046       //     beginning of I.
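      //
      // For example (illustrative): 4-byte accesses at offsets 0 and 4 do not
      // overlap, but an 8-byte access at offset 0 overlaps a 4-byte access at
      // offset 4.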
1047       const APInt &IOffset = OffsetIt->second;
1048       unsigned IElemSize = DL.getTypeStoreSize(getLoadStoreType(I));
1049       if (IOffset == ChainElemOffset ||
1050           (IOffset.sle(ChainElemOffset) &&
1051            (IOffset + IElemSize).sgt(ChainElemOffset)) ||
1052           (ChainElemOffset.sle(IOffset) &&
1053            (ChainElemOffset + ChainElemSize).sgt(IOffset))) {
1054         LLVM_DEBUG({
1055           // Double check that AA also sees this alias.  If not, we probably
1056           // have a bug.
1057           ModRefInfo MR = AA.getModRefInfo(I, MemoryLocation::get(ChainElem));
1058           assert(IsLoadChain ? isModSet(MR) : isModOrRefSet(MR));
1059           dbgs() << "LSV: Found alias in chain: " << *I << "\n";
1060         });
1061         return false; // We found an aliasing instruction; bail.
1062       }
1063 
1064       continue; // We're confident there's no alias.
1065     }
1066 
1067     LLVM_DEBUG(dbgs() << "LSV: Querying AA for " << *I << "\n");
1068     ModRefInfo MR = AA.getModRefInfo(I, MemoryLocation::get(ChainElem));
1069     if (IsLoadChain ? isModSet(MR) : isModOrRefSet(MR)) {
1070       LLVM_DEBUG(dbgs() << "LSV: Found alias in chain:\n"
1071                         << "  Aliasing instruction:\n"
1072                         << "    " << *I << '\n'
1073                         << "  Aliased instruction and pointer:\n"
1074                         << "    " << *ChainElem << '\n'
1075                         << "    " << *getLoadStorePointerOperand(ChainElem)
1076                         << '\n');
1077 
1078       return false;
1079     }
1080   }
1081   return true;
1082 }
1083 
1084 static bool checkNoWrapFlags(Instruction *I, bool Signed) {
1085   BinaryOperator *BinOpI = cast<BinaryOperator>(I);
1086   return (Signed && BinOpI->hasNoSignedWrap()) ||
1087          (!Signed && BinOpI->hasNoUnsignedWrap());
1088 }
1089 
1090 static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
1091                                    unsigned MatchingOpIdxA, Instruction *AddOpB,
1092                                    unsigned MatchingOpIdxB, bool Signed) {
1093   LLVM_DEBUG(dbgs() << "LSV: checkIfSafeAddSequence IdxDiff=" << IdxDiff
1094                     << ", AddOpA=" << *AddOpA << ", MatchingOpIdxA="
1095                     << MatchingOpIdxA << ", AddOpB=" << *AddOpB
1096                     << ", MatchingOpIdxB=" << MatchingOpIdxB
1097                     << ", Signed=" << Signed << "\n");
1098   // If both OpA and OpB are adds with NSW/NUW and with one of the operands
1099   // being the same, we can guarantee that the transformation is safe if we can
1100   // prove that OpA won't overflow when IdxDiff is added to its other operand.
1101   // For example:
1102   //  %tmp7 = add nsw i32 %tmp2, %v0
1103   //  %tmp8 = sext i32 %tmp7 to i64
1104   //  ...
1105   //  %tmp11 = add nsw i32 %v0, 1
1106   //  %tmp12 = add nsw i32 %tmp2, %tmp11
1107   //  %tmp13 = sext i32 %tmp12 to i64
1108   //
1109   //  Both %tmp7 and %tmp12 have the nsw flag and the first operand is %tmp2.
1110   //  It's guaranteed that adding 1 to %tmp7 won't overflow because %tmp11 adds
1111   //  1 to %v0 and both %tmp11 and %tmp12 have the nsw flag.
1112   assert(AddOpA->getOpcode() == Instruction::Add &&
1113          AddOpB->getOpcode() == Instruction::Add &&
1114          checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed));
1115   if (AddOpA->getOperand(MatchingOpIdxA) ==
1116       AddOpB->getOperand(MatchingOpIdxB)) {
1117     Value *OtherOperandA = AddOpA->getOperand(MatchingOpIdxA == 1 ? 0 : 1);
1118     Value *OtherOperandB = AddOpB->getOperand(MatchingOpIdxB == 1 ? 0 : 1);
1119     Instruction *OtherInstrA = dyn_cast<Instruction>(OtherOperandA);
1120     Instruction *OtherInstrB = dyn_cast<Instruction>(OtherOperandB);
1121     // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
1122     if (OtherInstrB && OtherInstrB->getOpcode() == Instruction::Add &&
1123         checkNoWrapFlags(OtherInstrB, Signed) &&
1124         isa<ConstantInt>(OtherInstrB->getOperand(1))) {
1125       int64_t CstVal =
1126           cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
1127       if (OtherInstrB->getOperand(0) == OtherOperandA &&
1128           IdxDiff.getSExtValue() == CstVal)
1129         return true;
1130     }
1131     // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
1132     if (OtherInstrA && OtherInstrA->getOpcode() == Instruction::Add &&
1133         checkNoWrapFlags(OtherInstrA, Signed) &&
1134         isa<ConstantInt>(OtherInstrA->getOperand(1))) {
1135       int64_t CstVal =
1136           cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
1137       if (OtherInstrA->getOperand(0) == OtherOperandB &&
1138           IdxDiff.getSExtValue() == -CstVal)
1139         return true;
1140     }
1141     // Match `x +nsw/nuw (y +nsw/nuw c)` and
1142     // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
1143     if (OtherInstrA && OtherInstrB &&
1144         OtherInstrA->getOpcode() == Instruction::Add &&
1145         OtherInstrB->getOpcode() == Instruction::Add &&
1146         checkNoWrapFlags(OtherInstrA, Signed) &&
1147         checkNoWrapFlags(OtherInstrB, Signed) &&
1148         isa<ConstantInt>(OtherInstrA->getOperand(1)) &&
1149         isa<ConstantInt>(OtherInstrB->getOperand(1))) {
1150       int64_t CstValA =
1151           cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
1152       int64_t CstValB =
1153           cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
1154       if (OtherInstrA->getOperand(0) == OtherInstrB->getOperand(0) &&
1155           IdxDiff.getSExtValue() == (CstValB - CstValA))
1156         return true;
1157     }
1158   }
1159   return false;
1160 }
1161 
1162 std::optional<APInt> Vectorizer::getConstantOffsetComplexAddrs(Value *PtrA,
1163                                                                Value *PtrB,
1164                                                                unsigned Depth) {
1165   LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs PtrA=" << *PtrA
1166                     << " PtrB=" << *PtrB << " Depth=" << Depth << "\n");
1167   auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
1168   auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
1169   if (!GEPA || !GEPB)
1170     return getConstantOffsetSelects(PtrA, PtrB, Depth);
1171 
1172   // Look through GEPs after checking they're the same except for the last
1173   // index.
1174   if (GEPA->getNumOperands() != GEPB->getNumOperands() ||
1175       GEPA->getPointerOperand() != GEPB->getPointerOperand())
1176     return std::nullopt;
1177   gep_type_iterator GTIA = gep_type_begin(GEPA);
1178   gep_type_iterator GTIB = gep_type_begin(GEPB);
1179   for (unsigned I = 0, E = GEPA->getNumIndices() - 1; I < E; ++I) {
1180     if (GTIA.getOperand() != GTIB.getOperand())
1181       return std::nullopt;
1182     ++GTIA;
1183     ++GTIB;
1184   }
1185 
1186   Instruction *OpA = dyn_cast<Instruction>(GTIA.getOperand());
1187   Instruction *OpB = dyn_cast<Instruction>(GTIB.getOperand());
1188   if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
1189       OpA->getType() != OpB->getType())
1190     return std::nullopt;
1191 
1192   uint64_t Stride = DL.getTypeAllocSize(GTIA.getIndexedType());
1193 
1194   // Only look through a ZExt/SExt.
1195   if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
1196     return std::nullopt;
1197 
1198   bool Signed = isa<SExtInst>(OpA);
1199 
1200   // At this point A could be a function parameter, i.e. not an instruction
1201   Value *ValA = OpA->getOperand(0);
1202   OpB = dyn_cast<Instruction>(OpB->getOperand(0));
1203   if (!OpB || ValA->getType() != OpB->getType())
1204     return std::nullopt;
1205 
1206   const SCEV *OffsetSCEVA = SE.getSCEV(ValA);
1207   const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
1208   const SCEV *IdxDiffSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA);
1209   if (IdxDiffSCEV == SE.getCouldNotCompute())
1210     return std::nullopt;
1211 
1212   ConstantRange IdxDiffRange = SE.getSignedRange(IdxDiffSCEV);
1213   if (!IdxDiffRange.isSingleElement())
1214     return std::nullopt;
1215   APInt IdxDiff = *IdxDiffRange.getSingleElement();
1216 
1217   LLVM_DEBUG(dbgs() << "LSV: gtConstantOffsetComplexAddrs IdxDiff=" << IdxDiff
1218                     << "\n");
1219 
1220   // Now we need to prove that adding IdxDiff to ValA won't overflow.
1221   bool Safe = false;
1222 
1223   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
1224   // ValA, we're okay.
1225   if (OpB->getOpcode() == Instruction::Add &&
1226       isa<ConstantInt>(OpB->getOperand(1)) &&
1227       IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
1228       checkNoWrapFlags(OpB, Signed))
1229     Safe = true;
1230 
1231   // Second attempt: check if we have eligible add NSW/NUW instruction
1232   // sequences.
1233   OpA = dyn_cast<Instruction>(ValA);
1234   if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
1235       OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) &&
1236       checkNoWrapFlags(OpB, Signed)) {
1237     // In the checks below a matching operand in OpA and OpB is an operand which
1238     // is the same in those two instructions.  Below we account for possible
1239     // orders of the operands of these add instructions.
1240     for (unsigned MatchingOpIdxA : {0, 1})
1241       for (unsigned MatchingOpIdxB : {0, 1})
1242         if (!Safe)
1243           Safe = checkIfSafeAddSequence(IdxDiff, OpA, MatchingOpIdxA, OpB,
1244                                         MatchingOpIdxB, Signed);
1245   }
1246 
1247   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
1248 
1249   // Third attempt:
1250   //
1251   // Assuming IdxDiff is positive: If all set bits of IdxDiff or any higher
1252   // order bit other than the sign bit are known to be zero in ValA, we can add
1253   // Diff to it while guaranteeing no overflow of any sort.
1254   //
1255   // If IdxDiff is negative, do the same, but swap ValA and ValB.
1256   if (!Safe) {
1257     // When computing known bits, use the GEPs as context instructions, since
1258     // they likely are in the same BB as the load/store.
1259     Instruction *ContextInst = GEPA->comesBefore(GEPB) ? GEPB : GEPA;
1260     KnownBits Known(BitWidth);
1261     computeKnownBits((IdxDiff.sge(0) ? ValA : OpB), Known, DL, 0, &AC,
1262                      ContextInst, &DT);
1263     APInt BitsAllowedToBeSet = Known.Zero.zext(IdxDiff.getBitWidth());
1264     if (Signed)
1265       BitsAllowedToBeSet.clearBit(BitWidth - 1);
1266     if (BitsAllowedToBeSet.ult(IdxDiff.abs()))
1267       return std::nullopt;
1268     Safe = true;
1269   }
1270 
1271   if (Safe)
1272     return IdxDiff * Stride;
1273   return std::nullopt;
1274 }
1275 
1276 std::optional<APInt>
1277 Vectorizer::getConstantOffsetSelects(Value *PtrA, Value *PtrB, unsigned Depth) {
1278   if (Depth++ == MaxDepth)
1279     return std::nullopt;
1280 
1281   if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
1282     if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
1283       if (SelectA->getCondition() != SelectB->getCondition())
1284         return std::nullopt;
1285       LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetSelects, PtrA=" << *PtrA
1286                         << ", PtrB=" << *PtrB << ", Depth=" << Depth << "\n");
1287       std::optional<APInt> TrueDiff = getConstantOffset(
1288           SelectA->getTrueValue(), SelectB->getTrueValue(), Depth);
1289       if (!TrueDiff.has_value())
1290         return std::nullopt;
1291       std::optional<APInt> FalseDiff = getConstantOffset(
1292           SelectA->getFalseValue(), SelectB->getFalseValue(), Depth);
1293       if (TrueDiff == FalseDiff)
1294         return TrueDiff;
1295     }
1296   }
1297   return std::nullopt;
1298 }
1299 
1300 EquivalenceClassMap
1301 Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
1302                                       BasicBlock::iterator End) {
1303   EquivalenceClassMap Ret;
1304 
1305   auto getUnderlyingObject = [](const Value *Ptr) -> const Value * {
1306     const Value *ObjPtr = llvm::getUnderlyingObject(Ptr);
1307     if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
1308       // The selects themselves are distinct instructions even if they share
1309       // the same condition and evaluate to consecutive pointers for the true
1310       // and false values of the condition.  Using the selects themselves for
1311       // grouping would therefore put consecutive accesses into different
1312       // lists; they would never even be checked for being consecutive and
1313       // so would not be vectorized.
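     // For example (hypothetical IR):
     //   %sel0 = select i1 %c, ptr %a, ptr %b
     //   %sel1 = select i1 %c, ptr %a4, ptr %b4   ; both arms 4 bytes later
     // Keying on %c puts accesses through %sel0 and %sel1 in the same class,
     // so they can later be recognized as consecutive.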
1314       return Sel->getCondition();
1315     }
1316     return ObjPtr;
1317   };
1318 
1319   for (Instruction &I : make_range(Begin, End)) {
1320     auto *LI = dyn_cast<LoadInst>(&I);
1321     auto *SI = dyn_cast<StoreInst>(&I);
1322     if (!LI && !SI)
1323       continue;
1324 
1325     if ((LI && !LI->isSimple()) || (SI && !SI->isSimple()))
1326       continue;
1327 
1328     if ((LI && !TTI.isLegalToVectorizeLoad(LI)) ||
1329         (SI && !TTI.isLegalToVectorizeStore(SI)))
1330       continue;
1331 
1332     Type *Ty = getLoadStoreType(&I);
1333     if (!VectorType::isValidElementType(Ty->getScalarType()))
1334       continue;
1335 
1336     // Skip weird non-byte sizes. They probably aren't worth the effort of
1337     // handling correctly.
1338     unsigned TySize = DL.getTypeSizeInBits(Ty);
1339     if ((TySize % 8) != 0)
1340       continue;
1341 
1342     // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
1343     // functions currently use an integer type for the vectorized
1344     // load/store and do not support casting between that integer type and a
1345     // vector of pointers (e.g. i64 to <2 x i16*>).
1346     if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
1347       continue;
1348 
1349     Value *Ptr = getLoadStorePointerOperand(&I);
1350     unsigned AS = Ptr->getType()->getPointerAddressSpace();
1351     unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
1352 
1353     unsigned VF = VecRegSize / TySize;
1354     VectorType *VecTy = dyn_cast<VectorType>(Ty);
1355 
1356     // Only handle power-of-two sized elements.
1357     if ((!VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(Ty))) ||
1358         (VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(VecTy->getScalarType()))))
1359       continue;
1360 
1361     // No point in looking at these if they're too big to vectorize.
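     // For example (illustrative): with a 128-bit vector register, a scalar
     // i32 access passes this check, but a single 128-bit access is skipped
     // because two of them would not fit in one vector register.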
1362     if (TySize > VecRegSize / 2 ||
1363         (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
1364       continue;
1365 
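     // The key below groups accesses that could plausibly form one chain: same
     // underlying object (or select condition), same address space, same
     // scalar element width, with loads kept separate from stores.  For
     // example (illustrative), two i32 loads at byte offsets 0 and 4 off the
     // same alloca land in the same class; an i32 load and an i32 store do
     // not.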
1366     Ret[{getUnderlyingObject(Ptr), AS,
1367          DL.getTypeSizeInBits(getLoadStoreType(&I)->getScalarType()),
1368          /*IsLoad=*/LI != nullptr}]
1369         .push_back(&I);
1370   }
1371 
1372   return Ret;
1373 }
1374 
1375 std::vector<Chain> Vectorizer::gatherChains(ArrayRef<Instruction *> Instrs) {
1376   if (Instrs.empty())
1377     return {};
1378 
1379   unsigned AS = getLoadStoreAddressSpace(Instrs[0]);
1380   unsigned ASPtrBits = DL.getIndexSizeInBits(AS);
1381 
1382 #ifndef NDEBUG
1383   // Check that Instrs is in BB order and all have the same addr space.
1384   for (size_t I = 1; I < Instrs.size(); ++I) {
1385     assert(Instrs[I - 1]->comesBefore(Instrs[I]));
1386     assert(getLoadStoreAddressSpace(Instrs[I]) == AS);
1387   }
1388 #endif
1389 
1390   // Machinery to build an MRU-hashtable of Chains.
1391   //
1392   // (Ideally this could be done with MapVector, but as currently implemented,
1393   // moving an element to the front of a MapVector is O(n).)
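     // Layout sketch (illustrative): each chain lives in a single
     // InstrListElem node from the bump allocator; MRU is an intrusive list
     // over those nodes ordered by recency, so moving a chain to the front is
     // O(1) and never copies the chain's elements.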
1394   struct InstrListElem : ilist_node<InstrListElem>,
1395                          std::pair<Instruction *, Chain> {
1396     explicit InstrListElem(Instruction *I)
1397         : std::pair<Instruction *, Chain>(I, {}) {}
1398   };
1399   struct InstrListElemDenseMapInfo {
1400     using PtrInfo = DenseMapInfo<InstrListElem *>;
1401     using IInfo = DenseMapInfo<Instruction *>;
1402     static InstrListElem *getEmptyKey() { return PtrInfo::getEmptyKey(); }
1403     static InstrListElem *getTombstoneKey() {
1404       return PtrInfo::getTombstoneKey();
1405     }
1406     static unsigned getHashValue(const InstrListElem *E) {
1407       return IInfo::getHashValue(E->first);
1408     }
1409     static bool isEqual(const InstrListElem *A, const InstrListElem *B) {
1410       if (A == getEmptyKey() || B == getEmptyKey())
1411         return A == getEmptyKey() && B == getEmptyKey();
1412       if (A == getTombstoneKey() || B == getTombstoneKey())
1413         return A == getTombstoneKey() && B == getTombstoneKey();
1414       return IInfo::isEqual(A->first, B->first);
1415     }
1416   };
1417   SpecificBumpPtrAllocator<InstrListElem> Allocator;
1418   simple_ilist<InstrListElem> MRU;
1419   DenseSet<InstrListElem *, InstrListElemDenseMapInfo> Chains;
1420 
1421   // Compare each instruction in `Instrs` to the leaders of the N most
1422   // recently-used chains.  This limits the O(n^2) behavior of this pass while
1423   // still allowing us to build arbitrarily long chains.
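     // For example (illustrative): given loads from p, p+4, q, and p+8, in
     // that order, the load from p seeds a chain at offset 0; p+4 and p+8 get
     // constant offsets 4 and 8 from that chain's leader; q yields no constant
     // offset to any leader and therefore seeds a new chain of its own.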
1424   for (Instruction *I : Instrs) {
1425     constexpr size_t MaxChainsToTry = 64;
1426 
1427     bool MatchFound = false;
1428     auto ChainIter = MRU.begin();
1429     for (size_t J = 0; J < MaxChainsToTry && ChainIter != MRU.end();
1430          ++J, ++ChainIter) {
1431       std::optional<APInt> Offset =
1432           getConstantOffset(getLoadStorePointerOperand(ChainIter->first),
1433                             getLoadStorePointerOperand(I));
1434       if (Offset.has_value()) {
1435         // `Offset` might not have the expected number of bits, if e.g. AS has a
1436         // different number of bits than opaque pointers.
1437         ChainIter->second.push_back(
1438             ChainElem{I, Offset.value().sextOrTrunc(ASPtrBits)});
1439         // Move ChainIter to the front of the MRU list.
1440         MRU.remove(*ChainIter);
1441         MRU.push_front(*ChainIter);
1442         MatchFound = true;
1443         break;
1444       }
1445     }
1446 
1447     if (!MatchFound) {
1448       APInt ZeroOffset(ASPtrBits, 0);
1449       InstrListElem *E = new (Allocator.Allocate()) InstrListElem(I);
1450       E->second.push_back(ChainElem{I, ZeroOffset});
1451       MRU.push_front(*E);
1452       Chains.insert(E);
1453     }
1454   }
1455 
1456   std::vector<Chain> Ret;
1457   Ret.reserve(Chains.size());
1458   // Iterate over MRU rather than Chains so the order is deterministic.
1459   for (auto &E : MRU)
1460     if (E.second.size() > 1)
1461       Ret.push_back(std::move(E.second));
1462   return Ret;
1463 }
1464 
1465 std::optional<APInt> Vectorizer::getConstantOffset(Value *PtrA, Value *PtrB,
1466                                                    unsigned Depth) {
1467   LLVM_DEBUG(dbgs() << "LSV: getConstantOffset, PtrA=" << *PtrA
1468                     << ", PtrB=" << *PtrB << ", Depth=" << Depth << "\n");
1469   unsigned OffsetBitWidth = DL.getIndexTypeSizeInBits(PtrA->getType());
1470   APInt OffsetA(OffsetBitWidth, 0);
1471   APInt OffsetB(OffsetBitWidth, 0);
1472   PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
1473   PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
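     // For example (illustrative): if PtrA is "gep inbounds i8, ptr %p, i64 12"
     // and PtrB is "gep inbounds i8, ptr %p, i64 20", both strip to %p with
     // OffsetA = 12 and OffsetB = 20, and the PtrA == PtrB case below returns 8.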
1474   unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType());
1475   if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType()))
1476     return std::nullopt;
1477 
1478   // If we have to shrink the pointer, stripAndAccumulateInBoundsConstantOffsets
1479   // should properly handle a possible overflow and the value should fit into
1480   // the smallest data type used in the cast/gep chain.
1481   assert(OffsetA.getSignificantBits() <= NewPtrBitWidth &&
1482          OffsetB.getSignificantBits() <= NewPtrBitWidth);
1483 
1484   OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
1485   OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
1486   if (PtrA == PtrB)
1487     return OffsetB - OffsetA;
1488 
1489   // Try to compute B - A.
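     // For example (illustrative): with PtrA = "gep i8, ptr %base, i64 %i" and
     // PtrB = "gep i8, ptr %base, i64 %j", where %j = add i64 %i, 8, the %i
     // terms cancel in the subtraction, SCEV folds PtrB - PtrA to the constant
     // 8, and its signed range is the single element 8.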
1490   const SCEV *DistScev = SE.getMinusSCEV(SE.getSCEV(PtrB), SE.getSCEV(PtrA));
1491   if (DistScev != SE.getCouldNotCompute()) {
1492     LLVM_DEBUG(dbgs() << "LSV: SCEV PtrB - PtrA =" << *DistScev << "\n");
1493     ConstantRange DistRange = SE.getSignedRange(DistScev);
1494     if (DistRange.isSingleElement())
1495       return OffsetB - OffsetA + *DistRange.getSingleElement();
1496   }
1497   std::optional<APInt> Diff = getConstantOffsetComplexAddrs(PtrA, PtrB, Depth);
1498   if (Diff.has_value())
1499     return OffsetB - OffsetA + Diff->sext(OffsetB.getBitWidth());
1500   return std::nullopt;
1501 }
1502