1 //===- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass merges loads/stores to/from sequential memory addresses into vector
10 // loads/stores.  Although there's nothing GPU-specific in here, this pass is
11 // motivated by the microarchitectural quirks of nVidia and AMD GPUs.
12 //
13 // (For simplicity below we talk about loads only, but everything also applies
14 // to stores.)
15 //
16 // This pass is intended to be run late in the pipeline, after other
17 // vectorization opportunities have been exploited.  So the assumption here is
18 // that immediately following our new vector load we'll need to extract out the
19 // individual elements of the load, so we can operate on them individually.
20 //
21 // On CPUs this transformation is usually not beneficial, because extracting the
22 // elements of a vector register is expensive on most architectures.  It's
23 // usually better just to load each element individually into its own scalar
24 // register.
25 //
26 // However, nVidia and AMD GPUs don't have proper vector registers.  Instead, a
27 // "vector load" loads directly into a series of scalar registers.  In effect,
28 // extracting the elements of the vector is free.  It's therefore always
29 // beneficial to vectorize a sequence of loads on these architectures.
30 //
31 // Vectorizing (perhaps a better name might be "coalescing") loads can have
32 // large performance impacts on GPU kernels, and opportunities for vectorizing
33 // are common in GPU code.  This pass tries very hard to find such
34 // opportunities; its runtime is quadratic in the number of loads in a BB.
35 //
36 // Some CPU architectures, such as ARM, have instructions that load into
37 // multiple scalar registers, similar to a GPU vectorized load.  In theory ARM
38 // could use this pass (with some modifications), but currently it implements
39 // its own pass to do something similar to what we do here.
40 //
41 // Overview of the algorithm and terminology in this pass:
42 //
43 //  - Break up each basic block into pseudo-BBs, composed of instructions which
44 //    are guaranteed to transfer control to their successors.
45 //  - Within a single pseudo-BB, find all loads, and group them into
46 //    "equivalence classes" according to getUnderlyingObject() and loaded
47 //    element size.  Do the same for stores.
48 //  - For each equivalence class, greedily build "chains".  Each chain has a
49 //    leader instruction, and every other member of the chain has a known
50 //    constant offset from the first instr in the chain.
51 //  - Break up chains so that they contain only contiguous accesses of legal
52 //    size with no intervening may-alias instrs.
53 //  - Convert each chain to vector instructions.
54 //
55 // The O(n^2) behavior of this pass comes from initially building the chains.
56 // In the worst case we have to compare each new instruction to all of those
57 // that came before. To limit this, we only calculate the offset to the leaders
58 // of the N most recently-used chains.
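//
// As a minimal illustration (hypothetical IR; the alignment, aliasing, and
// legality checks described above are omitted), two adjacent i32 loads
//
//   %p1 = getelementptr inbounds i32, ptr %p, i64 1
//   %x0 = load i32, ptr %p
//   %x1 = load i32, ptr %p1
//
// become roughly
//
//   %v  = load <2 x i32>, ptr %p
//   %x0 = extractelement <2 x i32> %v, i32 0
//   %x1 = extractelement <2 x i32> %v, i32 1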
59 
60 #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/ArrayRef.h"
63 #include "llvm/ADT/DenseMap.h"
64 #include "llvm/ADT/MapVector.h"
65 #include "llvm/ADT/PostOrderIterator.h"
66 #include "llvm/ADT/STLExtras.h"
67 #include "llvm/ADT/Sequence.h"
68 #include "llvm/ADT/SmallPtrSet.h"
69 #include "llvm/ADT/SmallVector.h"
70 #include "llvm/ADT/Statistic.h"
71 #include "llvm/ADT/iterator_range.h"
72 #include "llvm/Analysis/AliasAnalysis.h"
73 #include "llvm/Analysis/AssumptionCache.h"
74 #include "llvm/Analysis/MemoryLocation.h"
75 #include "llvm/Analysis/ScalarEvolution.h"
76 #include "llvm/Analysis/TargetTransformInfo.h"
77 #include "llvm/Analysis/ValueTracking.h"
78 #include "llvm/Analysis/VectorUtils.h"
79 #include "llvm/IR/Attributes.h"
80 #include "llvm/IR/BasicBlock.h"
81 #include "llvm/IR/ConstantRange.h"
82 #include "llvm/IR/Constants.h"
83 #include "llvm/IR/DataLayout.h"
84 #include "llvm/IR/DerivedTypes.h"
85 #include "llvm/IR/Dominators.h"
86 #include "llvm/IR/Function.h"
87 #include "llvm/IR/GetElementPtrTypeIterator.h"
88 #include "llvm/IR/IRBuilder.h"
89 #include "llvm/IR/InstrTypes.h"
90 #include "llvm/IR/Instruction.h"
91 #include "llvm/IR/Instructions.h"
92 #include "llvm/IR/LLVMContext.h"
93 #include "llvm/IR/Module.h"
94 #include "llvm/IR/Type.h"
95 #include "llvm/IR/Value.h"
96 #include "llvm/InitializePasses.h"
97 #include "llvm/Pass.h"
98 #include "llvm/Support/Alignment.h"
99 #include "llvm/Support/Casting.h"
100 #include "llvm/Support/Debug.h"
101 #include "llvm/Support/KnownBits.h"
102 #include "llvm/Support/MathExtras.h"
103 #include "llvm/Support/ModRef.h"
104 #include "llvm/Support/raw_ostream.h"
105 #include "llvm/Transforms/Utils/Local.h"
106 #include <algorithm>
107 #include <cassert>
108 #include <cstdint>
109 #include <cstdlib>
110 #include <iterator>
111 #include <numeric>
112 #include <optional>
113 #include <tuple>
114 #include <type_traits>
115 #include <utility>
116 #include <vector>
117 
118 using namespace llvm;
119 
120 #define DEBUG_TYPE "load-store-vectorizer"
121 
122 STATISTIC(NumVectorInstructions, "Number of vector accesses generated");
123 STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized");
124 
125 namespace {
126 
127 // Equivalence class key, the initial tuple by which we group loads/stores.
128 // Loads/stores with different EqClassKeys are never merged.
129 //
130 // (We could in theory remove element-size from this tuple.  We'd just need
131 // to fix up the vector packing/unpacking code.)
132 using EqClassKey =
133     std::tuple<const Value * /* result of getUnderlyingObject() */,
134                unsigned /* AddrSpace */,
135                unsigned /* Load/Store element size bits */,
136                char /* IsLoad; char b/c bool can't be a DenseMap key */
137                >;
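// For illustration (hypothetical values): i32 and <2 x i32> loads whose
// pointers trace back to the same alloca %buf in addrspace(5) all share the
// key {%buf, 5, 32, 1}, while an i16 load from %buf lands in a separate class
// because its element size differs.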
138 [[maybe_unused]] llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
139                                                const EqClassKey &K) {
140   const auto &[UnderlyingObject, AddrSpace, ElementSize, IsLoad] = K;
141   return OS << (IsLoad ? "load" : "store") << " of " << *UnderlyingObject
142             << " of element size " << ElementSize << " bits in addrspace "
143             << AddrSpace;
144 }
145 
146 // A Chain is a set of instructions such that:
147 //  - All instructions have the same equivalence class, so in particular all are
148 //    loads, or all are stores.
149 //  - We know the address accessed by the i'th chain elem relative to the
150 //    chain's leader instruction, which is the first instr of the chain in BB
151 //    order.
152 //
153 // Chains have two canonical orderings:
154 //  - BB order, sorted by Instr->comesBefore.
155 //  - Offset order, sorted by OffsetFromLeader.
156 // This pass switches back and forth between these orders.
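// For example (hypothetical), i32 loads from %p+4, %p+0, and %p+8 appearing in
// that textual order form one chain whose leader is the %p+4 load:
//   BB order:     (%p+4, offset 0), (%p+0, offset -4), (%p+8, offset +4)
//   Offset order: (%p+0, offset -4), (%p+4, offset 0), (%p+8, offset +4)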
157 struct ChainElem {
158   Instruction *Inst;
159   APInt OffsetFromLeader;
160   ChainElem(Instruction *Inst, APInt OffsetFromLeader)
161       : Inst(std::move(Inst)), OffsetFromLeader(std::move(OffsetFromLeader)) {}
162 };
163 using Chain = SmallVector<ChainElem, 1>;
164 
165 void sortChainInBBOrder(Chain &C) {
166   sort(C, [](auto &A, auto &B) { return A.Inst->comesBefore(B.Inst); });
167 }
168 
169 void sortChainInOffsetOrder(Chain &C) {
170   sort(C, [](const auto &A, const auto &B) {
171     if (A.OffsetFromLeader != B.OffsetFromLeader)
172       return A.OffsetFromLeader.slt(B.OffsetFromLeader);
173     return A.Inst->comesBefore(B.Inst); // stable tiebreaker
174   });
175 }
176 
177 [[maybe_unused]] void dumpChain(ArrayRef<ChainElem> C) {
178   for (const auto &E : C) {
179     dbgs() << "  " << *E.Inst << " (offset " << E.OffsetFromLeader << ")\n";
180   }
181 }
182 
183 using EquivalenceClassMap =
184     MapVector<EqClassKey, SmallVector<Instruction *, 8>>;
185 
186 // FIXME: Assuming stack alignment of 4 is always good enough
187 constexpr unsigned StackAdjustedAlignment = 4;
188 
189 Instruction *propagateMetadata(Instruction *I, const Chain &C) {
190   SmallVector<Value *, 8> Values;
191   for (const ChainElem &E : C)
192     Values.emplace_back(E.Inst);
193   return propagateMetadata(I, Values);
194 }
195 
196 bool isInvariantLoad(const Instruction *I) {
197   const LoadInst *LI = dyn_cast<LoadInst>(I);
198   return LI != nullptr && LI->hasMetadata(LLVMContext::MD_invariant_load);
199 }
200 
201 /// Reorders the instructions that I depends on (the instructions defining its
202 /// operands), to ensure they dominate I.
203 void reorder(Instruction *I) {
204   SmallPtrSet<Instruction *, 16> InstructionsToMove;
205   SmallVector<Instruction *, 16> Worklist;
206 
207   Worklist.emplace_back(I);
208   while (!Worklist.empty()) {
209     Instruction *IW = Worklist.pop_back_val();
210     int NumOperands = IW->getNumOperands();
211     for (int Idx = 0; Idx < NumOperands; Idx++) {
212       Instruction *IM = dyn_cast<Instruction>(IW->getOperand(Idx));
213       if (!IM || IM->getOpcode() == Instruction::PHI)
214         continue;
215 
216       // If IM is in another BB, no need to move it, because this pass only
217       // vectorizes instructions within one BB.
218       if (IM->getParent() != I->getParent())
219         continue;
220 
221       assert(IM != I && "Unexpected cycle while re-ordering instructions");
222 
223       if (!IM->comesBefore(I)) {
224         InstructionsToMove.insert(IM);
225         Worklist.emplace_back(IM);
226       }
227     }
228   }
229 
230   // All instructions to move should follow I. Start from I, not from begin().
231   for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;) {
232     Instruction *IM = &*(BBI++);
233     if (!InstructionsToMove.contains(IM))
234       continue;
235     IM->moveBefore(I->getIterator());
236   }
237 }
238 
239 class Vectorizer {
240   Function &F;
241   AliasAnalysis &AA;
242   AssumptionCache &AC;
243   DominatorTree &DT;
244   ScalarEvolution &SE;
245   TargetTransformInfo &TTI;
246   const DataLayout &DL;
247   IRBuilder<> Builder;
248 
249   // We could erase instrs right after vectorizing them, but that can mess up
250   // our BB iterators, and also can make the equivalence class keys point to
251   // freed memory.  This is fixable, but it's simpler just to wait until we're
252   // done with the BB and erase all at once.
253   SmallVector<Instruction *, 128> ToErase;
254 
255 public:
256   Vectorizer(Function &F, AliasAnalysis &AA, AssumptionCache &AC,
257              DominatorTree &DT, ScalarEvolution &SE, TargetTransformInfo &TTI)
258       : F(F), AA(AA), AC(AC), DT(DT), SE(SE), TTI(TTI),
259         DL(F.getDataLayout()), Builder(SE.getContext()) {}
260 
261   bool run();
262 
263 private:
264   static const unsigned MaxDepth = 3;
265 
266   /// Runs the vectorizer on a "pseudo basic block", which is a range of
267   /// instructions [Begin, End) within one BB all of which have
268   /// isGuaranteedToTransferExecutionToSuccessor(I) == true.
269   bool runOnPseudoBB(BasicBlock::iterator Begin, BasicBlock::iterator End);
270 
271   /// Runs the vectorizer on one equivalence class, i.e. one set of loads/stores
272   /// in the same BB with the same value for getUnderlyingObject() etc.
273   bool runOnEquivalenceClass(const EqClassKey &EqClassKey,
274                              ArrayRef<Instruction *> EqClass);
275 
276   /// Runs the vectorizer on one chain, i.e. a subset of an equivalence class
277   /// where all instructions access a known, constant offset from the first
278   /// instruction.
279   bool runOnChain(Chain &C);
280 
281   /// Splits the chain into subchains of instructions which read/write a
282   /// contiguous block of memory.  Discards any length-1 subchains (because
283   /// there's nothing to vectorize in there).
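  /// For example (hypothetical offsets), a chain of 4-byte accesses at offsets
  /// {0, 4, 8, 16} is split into {0, 4, 8} and {16}, and the latter is
  /// discarded as a length-1 subchain.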
284   std::vector<Chain> splitChainByContiguity(Chain &C);
285 
286   /// Splits the chain into subchains where it's safe to hoist loads up to the
287   /// beginning of the sub-chain and it's safe to sink stores down to the end of
288   /// the sub-chain.  Discards any length-1 subchains.
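  /// For example (hypothetical), given
  ///   load A[0]; load A[1]; store X; load A[2]
  /// where the store to X may alias A, the result is {A[0], A[1]} plus a
  /// length-1 remainder {A[2]}, because hoisting the A[2] load above the store
  /// would be unsafe.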
289   std::vector<Chain> splitChainByMayAliasInstrs(Chain &C);
290 
291   /// Splits the chain into subchains that make legal, aligned accesses.
292   /// Discards any length-1 subchains.
293   std::vector<Chain> splitChainByAlignment(Chain &C);
294 
295   /// Converts the instrs in the chain into a single vectorized load or store.
296   /// Adds the old scalar loads/stores to ToErase.
297   bool vectorizeChain(Chain &C);
298 
299   /// Tries to compute the offset in bytes PtrB - PtrA.
300   std::optional<APInt> getConstantOffset(Value *PtrA, Value *PtrB,
301                                          Instruction *ContextInst,
302                                          unsigned Depth = 0);
303   std::optional<APInt> getConstantOffsetComplexAddrs(Value *PtrA, Value *PtrB,
304                                                      Instruction *ContextInst,
305                                                      unsigned Depth);
306   std::optional<APInt> getConstantOffsetSelects(Value *PtrA, Value *PtrB,
307                                                 Instruction *ContextInst,
308                                                 unsigned Depth);
309 
310   /// Gets the element type of the vector that the chain will load or store.
311   /// This is nontrivial because the chain may contain elements of different
312   /// types; e.g. it's legal to have a chain that contains both i32 and float.
313   Type *getChainElemTy(const Chain &C);
314 
315   /// Determines whether ChainElem can be moved up (if IsLoad) or down (if
316   /// !IsLoad) to ChainBegin -- i.e. there are no intervening may-alias
317   /// instructions.
318   ///
319   /// The map ChainElemOffsets must contain all of the elements in
320   /// [ChainBegin, ChainElem] and their offsets from some arbitrary base
321   /// address.  It's ok if it contains additional entries.
322   template <bool IsLoadChain>
323   bool isSafeToMove(
324       Instruction *ChainElem, Instruction *ChainBegin,
325       const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);
326 
327   /// Merges the equivalence classes if they have underlying objects that differ
328   /// by one level of indirection (i.e., one is a getelementptr and the other is
329   /// the base pointer in that getelementptr).
330   void mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const;
331 
332   /// Collects loads and stores grouped by "equivalence class", where:
333   ///   - all elements in an eq class are a load or all are a store,
334   ///   - they all load/store the same element size (it's OK to have e.g. i8 and
335   ///     <4 x i8> in the same class, but not i32 and <4 x i8>), and
336   ///   - they all have the same value for getUnderlyingObject().
337   EquivalenceClassMap collectEquivalenceClasses(BasicBlock::iterator Begin,
338                                                 BasicBlock::iterator End);
339 
340   /// Partitions Instrs into "chains" where every instruction has a known
341   /// constant offset from the first instr in the chain.
342   ///
343   /// Postcondition: For all i, ret[i][0].OffsetFromLeader == 0, because the first
344   /// instr in the chain is the leader and is at distance 0 from itself.
345   std::vector<Chain> gatherChains(ArrayRef<Instruction *> Instrs);
346 };
347 
348 class LoadStoreVectorizerLegacyPass : public FunctionPass {
349 public:
350   static char ID;
351 
352   LoadStoreVectorizerLegacyPass() : FunctionPass(ID) {
353     initializeLoadStoreVectorizerLegacyPassPass(
354         *PassRegistry::getPassRegistry());
355   }
356 
357   bool runOnFunction(Function &F) override;
358 
359   StringRef getPassName() const override {
360     return "GPU Load and Store Vectorizer";
361   }
362 
363   void getAnalysisUsage(AnalysisUsage &AU) const override {
364     AU.addRequired<AAResultsWrapperPass>();
365     AU.addRequired<AssumptionCacheTracker>();
366     AU.addRequired<ScalarEvolutionWrapperPass>();
367     AU.addRequired<DominatorTreeWrapperPass>();
368     AU.addRequired<TargetTransformInfoWrapperPass>();
369     AU.setPreservesCFG();
370   }
371 };
372 
373 } // end anonymous namespace
374 
375 char LoadStoreVectorizerLegacyPass::ID = 0;
376 
377 INITIALIZE_PASS_BEGIN(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
378                       "Vectorize load and store instructions", false, false)
379 INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
380 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
381 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
382 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
383 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
384 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
385 INITIALIZE_PASS_END(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
386                     "Vectorize load and store instructions", false, false)
387 
388 Pass *llvm::createLoadStoreVectorizerPass() {
389   return new LoadStoreVectorizerLegacyPass();
390 }
391 
392 bool LoadStoreVectorizerLegacyPass::runOnFunction(Function &F) {
393   // Don't vectorize when the attribute NoImplicitFloat is used.
394   if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat))
395     return false;
396 
397   AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
398   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
399   ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
400   TargetTransformInfo &TTI =
401       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
402 
403   AssumptionCache &AC =
404       getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
405 
406   return Vectorizer(F, AA, AC, DT, SE, TTI).run();
407 }
408 
409 PreservedAnalyses LoadStoreVectorizerPass::run(Function &F,
410                                                FunctionAnalysisManager &AM) {
411   // Don't vectorize when the attribute NoImplicitFloat is used.
412   if (F.hasFnAttribute(Attribute::NoImplicitFloat))
413     return PreservedAnalyses::all();
414 
415   AliasAnalysis &AA = AM.getResult<AAManager>(F);
416   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
417   ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
418   TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
419   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
420 
421   bool Changed = Vectorizer(F, AA, AC, DT, SE, TTI).run();
422   PreservedAnalyses PA;
423   PA.preserveSet<CFGAnalyses>();
424   return Changed ? PA : PreservedAnalyses::all();
425 }
426 
427 bool Vectorizer::run() {
428   bool Changed = false;
429   // Break up the BB if there are any instrs which aren't guaranteed to transfer
430   // execution to their successor.
431   //
432   // Consider, for example:
433   //
434   //   def assert_arr_len(int n) { if (n < 2) exit(); }
435   //
436   //   load arr[0]
437   //   call assert_arr_len(arr.length)
438   //   load arr[1]
439   //
440   // Even though assert_arr_len does not read or write any memory, we can't
441   // speculate the second load before the call.  More info at
442   // https://github.com/llvm/llvm-project/issues/52950.
443   for (BasicBlock *BB : post_order(&F)) {
444     // BB must at least have a terminator.
445     assert(!BB->empty());
446 
447     SmallVector<BasicBlock::iterator, 8> Barriers;
448     Barriers.emplace_back(BB->begin());
449     for (Instruction &I : *BB)
450       if (!isGuaranteedToTransferExecutionToSuccessor(&I))
451         Barriers.emplace_back(I.getIterator());
452     Barriers.emplace_back(BB->end());
453 
454     for (auto It = Barriers.begin(), End = std::prev(Barriers.end()); It != End;
455          ++It)
456       Changed |= runOnPseudoBB(*It, *std::next(It));
457 
458     for (Instruction *I : ToErase) {
459       auto *PtrOperand = getLoadStorePointerOperand(I);
460       if (I->use_empty())
461         I->eraseFromParent();
462       RecursivelyDeleteTriviallyDeadInstructions(PtrOperand);
463     }
464     ToErase.clear();
465   }
466 
467   return Changed;
468 }
469 
470 bool Vectorizer::runOnPseudoBB(BasicBlock::iterator Begin,
471                                BasicBlock::iterator End) {
472   LLVM_DEBUG({
473     dbgs() << "LSV: Running on pseudo-BB [" << *Begin << " ... ";
474     if (End != Begin->getParent()->end())
475       dbgs() << *End;
476     else
477       dbgs() << "<BB end>";
478     dbgs() << ")\n";
479   });
480 
481   bool Changed = false;
482   for (const auto &[EqClassKey, EqClass] :
483        collectEquivalenceClasses(Begin, End))
484     Changed |= runOnEquivalenceClass(EqClassKey, EqClass);
485 
486   return Changed;
487 }
488 
489 bool Vectorizer::runOnEquivalenceClass(const EqClassKey &EqClassKey,
490                                        ArrayRef<Instruction *> EqClass) {
491   bool Changed = false;
492 
493   LLVM_DEBUG({
494     dbgs() << "LSV: Running on equivalence class of size " << EqClass.size()
495            << " keyed on " << EqClassKey << ":\n";
496     for (Instruction *I : EqClass)
497       dbgs() << "  " << *I << "\n";
498   });
499 
500   std::vector<Chain> Chains = gatherChains(EqClass);
501   LLVM_DEBUG(dbgs() << "LSV: Got " << Chains.size()
502                     << " nontrivial chains.\n";);
503   for (Chain &C : Chains)
504     Changed |= runOnChain(C);
505   return Changed;
506 }
507 
508 bool Vectorizer::runOnChain(Chain &C) {
509   LLVM_DEBUG({
510     dbgs() << "LSV: Running on chain with " << C.size() << " instructions:\n";
511     dumpChain(C);
512   });
513 
514   // Split up the chain into increasingly smaller chains, until we can finally
515   // vectorize the chains.
516   //
517   // (Don't be scared by the depth of the loop nest here.  These operations are
518   // all at worst O(n lg n) in the number of instructions, and splitting chains
519   // doesn't change the number of instrs.  So the whole loop nest is O(n lg n).)
520   bool Changed = false;
521   for (auto &C : splitChainByMayAliasInstrs(C))
522     for (auto &C : splitChainByContiguity(C))
523       for (auto &C : splitChainByAlignment(C))
524         Changed |= vectorizeChain(C);
525   return Changed;
526 }
527 
528 std::vector<Chain> Vectorizer::splitChainByMayAliasInstrs(Chain &C) {
529   if (C.empty())
530     return {};
531 
532   sortChainInBBOrder(C);
533 
534   LLVM_DEBUG({
535     dbgs() << "LSV: splitChainByMayAliasInstrs considering chain:\n";
536     dumpChain(C);
537   });
538 
539   // We know that elements in the chain with non-overlapping offsets can't
540   // alias, but AA may not be smart enough to figure this out.  Use a
541   // hashtable.
542   DenseMap<Instruction *, APInt /*OffsetFromLeader*/> ChainOffsets;
543   for (const auto &E : C)
544     ChainOffsets.insert({&*E.Inst, E.OffsetFromLeader});
545 
546   // Loads get hoisted up to the first load in the chain.  Stores get sunk
547   // down to the last store in the chain.  Our algorithm for loads is:
548   //
549   //  - Take the first element of the chain.  This is the start of a new chain.
550   //  - Take the next element of `Chain` and check for may-alias instructions
551   //    up to the start of NewChain.  If no may-alias instrs, add it to
552   //    NewChain.  Otherwise, start a new NewChain.
553   //
554   // For stores it's the same except in the reverse direction.
555   //
556   // We expect IsLoad to be an std::bool_constant.
557   auto Impl = [&](auto IsLoad) {
558     // MSVC is unhappy if IsLoad is a capture, so pass it as an arg.
559     auto [ChainBegin, ChainEnd] = [&](auto IsLoad) {
560       if constexpr (IsLoad())
561         return std::make_pair(C.begin(), C.end());
562       else
563         return std::make_pair(C.rbegin(), C.rend());
564     }(IsLoad);
565     assert(ChainBegin != ChainEnd);
566 
567     std::vector<Chain> Chains;
568     SmallVector<ChainElem, 1> NewChain;
569     NewChain.emplace_back(*ChainBegin);
570     for (auto ChainIt = std::next(ChainBegin); ChainIt != ChainEnd; ++ChainIt) {
571       if (isSafeToMove<IsLoad>(ChainIt->Inst, NewChain.front().Inst,
572                                ChainOffsets)) {
573         LLVM_DEBUG(dbgs() << "LSV: No intervening may-alias instrs; can merge "
574                           << *ChainIt->Inst << " into " << *ChainBegin->Inst
575                           << "\n");
576         NewChain.emplace_back(*ChainIt);
577       } else {
578         LLVM_DEBUG(
579             dbgs() << "LSV: Found intervening may-alias instrs; cannot merge "
580                    << *ChainIt->Inst << " into " << *ChainBegin->Inst << "\n");
581         if (NewChain.size() > 1) {
582           LLVM_DEBUG({
583             dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n";
584             dumpChain(NewChain);
585           });
586           Chains.emplace_back(std::move(NewChain));
587         }
588 
589         // Start a new chain.
590         NewChain = SmallVector<ChainElem, 1>({*ChainIt});
591       }
592     }
593     if (NewChain.size() > 1) {
594       LLVM_DEBUG({
595         dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n";
596         dumpChain(NewChain);
597       });
598       Chains.emplace_back(std::move(NewChain));
599     }
600     return Chains;
601   };
602 
603   if (isa<LoadInst>(C[0].Inst))
604     return Impl(/*IsLoad=*/std::bool_constant<true>());
605 
606   assert(isa<StoreInst>(C[0].Inst));
607   return Impl(/*IsLoad=*/std::bool_constant<false>());
608 }
609 
610 std::vector<Chain> Vectorizer::splitChainByContiguity(Chain &C) {
611   if (C.empty())
612     return {};
613 
614   sortChainInOffsetOrder(C);
615 
616   LLVM_DEBUG({
617     dbgs() << "LSV: splitChainByContiguity considering chain:\n";
618     dumpChain(C);
619   });
620 
621   std::vector<Chain> Ret;
622   Ret.push_back({C.front()});
623 
624   for (auto It = std::next(C.begin()), End = C.end(); It != End; ++It) {
625     // `Prev` accesses offsets [Prev.OffsetFromLeader, PrevReadEnd).
626     auto &CurChain = Ret.back();
627     const ChainElem &Prev = CurChain.back();
628     unsigned SzBits = DL.getTypeSizeInBits(getLoadStoreType(&*Prev.Inst));
629     assert(SzBits % 8 == 0 && "Non-byte sizes should have been filtered out by "
630                               "collectEquivalenceClasses");
631     APInt PrevReadEnd = Prev.OffsetFromLeader + SzBits / 8;
632 
633     // Add this instruction to the end of the current chain, or start a new one.
634     bool AreContiguous = It->OffsetFromLeader == PrevReadEnd;
635     LLVM_DEBUG(dbgs() << "LSV: Instructions are "
636                       << (AreContiguous ? "" : "not ") << "contiguous: "
637                       << *Prev.Inst << " (ends at offset " << PrevReadEnd
638                       << ") -> " << *It->Inst << " (starts at offset "
639                       << It->OffsetFromLeader << ")\n");
640     if (AreContiguous)
641       CurChain.push_back(*It);
642     else
643       Ret.push_back({*It});
644   }
645 
646   // Filter out length-1 chains, these are uninteresting.
647   llvm::erase_if(Ret, [](const auto &Chain) { return Chain.size() <= 1; });
648   return Ret;
649 }
650 
651 Type *Vectorizer::getChainElemTy(const Chain &C) {
652   assert(!C.empty());
653   // The rules are:
654   //  - If there are any pointer types in the chain, use an integer type.
655   //  - Prefer an integer type if it appears in the chain.
656   //  - Otherwise, use the first type in the chain.
657   //
658   // The rule about pointer types is a simplification when we merge e.g.  a load
659   // of a ptr and a double.  There's no direct conversion from a ptr to a
660   // double; it requires a ptrtoint followed by a bitcast.
661   //
662   // It's unclear to me if the other rules have any practical effect, but we do
663   // it to match this pass's previous behavior.
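  // For example (hypothetical chains, assuming 64-bit pointers): {float, i32}
  // yields i32, {ptr, i64} yields i64, and {float, float} yields float.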
664   if (any_of(C, [](const ChainElem &E) {
665         return getLoadStoreType(E.Inst)->getScalarType()->isPointerTy();
666       })) {
667     return Type::getIntNTy(
668         F.getContext(),
669         DL.getTypeSizeInBits(getLoadStoreType(C[0].Inst)->getScalarType()));
670   }
671 
672   for (const ChainElem &E : C)
673     if (Type *T = getLoadStoreType(E.Inst)->getScalarType(); T->isIntegerTy())
674       return T;
675   return getLoadStoreType(C[0].Inst)->getScalarType();
676 }
677 
678 std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
679   // We use a simple greedy algorithm.
680   //  - Given a chain of length N, find all prefixes that
681   //    (a) are not longer than the max register length, and
682   //    (b) are a power of 2.
683   //  - Starting from the longest prefix, try to create a vector of that length.
684   //  - If one of them works, great.  Repeat the algorithm on any remaining
685   //    elements in the chain.
686   //  - If none of them work, discard the first element and repeat on a chain
687   //    of length N-1.
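  // For example (hypothetical), with 16-byte vector registers and a chain of
  // six contiguous i32 accesses, the candidates starting at element 0 are the
  // prefixes of at most 16 bytes; the 4-element (16-byte) candidate is tried
  // first, and if it's legal and fast we emit it and resume the search at
  // element 4.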
688   if (C.empty())
689     return {};
690 
691   sortChainInOffsetOrder(C);
692 
693   LLVM_DEBUG({
694     dbgs() << "LSV: splitChainByAlignment considering chain:\n";
695     dumpChain(C);
696   });
697 
698   bool IsLoadChain = isa<LoadInst>(C[0].Inst);
699   auto GetVectorFactor = [&](unsigned VF, unsigned LoadStoreSize,
700                              unsigned ChainSizeBytes, VectorType *VecTy) {
701     return IsLoadChain ? TTI.getLoadVectorFactor(VF, LoadStoreSize,
702                                                  ChainSizeBytes, VecTy)
703                        : TTI.getStoreVectorFactor(VF, LoadStoreSize,
704                                                   ChainSizeBytes, VecTy);
705   };
706 
707 #ifndef NDEBUG
708   for (const auto &E : C) {
709     Type *Ty = getLoadStoreType(E.Inst)->getScalarType();
710     assert(isPowerOf2_32(DL.getTypeSizeInBits(Ty)) &&
711            "Should have filtered out non-power-of-two elements in "
712            "collectEquivalenceClasses.");
713   }
714 #endif
715 
716   unsigned AS = getLoadStoreAddressSpace(C[0].Inst);
717   unsigned VecRegBytes = TTI.getLoadStoreVecRegBitWidth(AS) / 8;
718 
719   std::vector<Chain> Ret;
720   for (unsigned CBegin = 0; CBegin < C.size(); ++CBegin) {
721     // Find candidate chains of size not greater than the largest vector reg.
722     // These chains are over the closed interval [CBegin, CEnd].
723     SmallVector<std::pair<unsigned /*CEnd*/, unsigned /*SizeBytes*/>, 8>
724         CandidateChains;
725     for (unsigned CEnd = CBegin + 1, Size = C.size(); CEnd < Size; ++CEnd) {
726       APInt Sz = C[CEnd].OffsetFromLeader +
727                  DL.getTypeStoreSize(getLoadStoreType(C[CEnd].Inst)) -
728                  C[CBegin].OffsetFromLeader;
729       if (Sz.sgt(VecRegBytes))
730         break;
731       CandidateChains.emplace_back(CEnd,
732                                    static_cast<unsigned>(Sz.getLimitedValue()));
733     }
734 
735     // Consider the longest chain first.
736     for (auto It = CandidateChains.rbegin(), End = CandidateChains.rend();
737          It != End; ++It) {
738       auto [CEnd, SizeBytes] = *It;
739       LLVM_DEBUG(
740           dbgs() << "LSV: splitChainByAlignment considering candidate chain ["
741                  << *C[CBegin].Inst << " ... " << *C[CEnd].Inst << "]\n");
742 
743       Type *VecElemTy = getChainElemTy(C);
744       // Note, VecElemTy's size in bits is a power of 2, but may be less than one
745       // byte.  For example, we can vectorize 2 x <2 x i4> to <4 x i4>, and in this
746       // case VecElemTy would be i4.
747       unsigned VecElemBits = DL.getTypeSizeInBits(VecElemTy);
748 
749       // SizeBytes and VecElemBits are powers of 2, so they divide evenly.
750       assert((8 * SizeBytes) % VecElemBits == 0);
751       unsigned NumVecElems = 8 * SizeBytes / VecElemBits;
752       FixedVectorType *VecTy = FixedVectorType::get(VecElemTy, NumVecElems);
753       unsigned VF = 8 * VecRegBytes / VecElemBits;
754 
755       // Check that TTI is happy with this vectorization factor.
756       unsigned TargetVF = GetVectorFactor(VF, VecElemBits,
757                                           VecElemBits * NumVecElems / 8, VecTy);
758       if (TargetVF != VF && TargetVF < NumVecElems) {
759         LLVM_DEBUG(
760             dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
761                       "because TargetVF="
762                    << TargetVF << " != VF=" << VF
763                    << " and TargetVF < NumVecElems=" << NumVecElems << "\n");
764         continue;
765       }
766 
767       // Is a load/store with this alignment allowed by TTI and at least as fast
768       // as an unvectorized load/store?
769       //
770       // TTI and F are passed as explicit captures to work around an MSVC misparse (??).
771       auto IsAllowedAndFast = [&, SizeBytes = SizeBytes, &TTI = TTI,
772                                &F = F](Align Alignment) {
773         if (Alignment.value() % SizeBytes == 0)
774           return true;
775         unsigned VectorizedSpeed = 0;
776         bool AllowsMisaligned = TTI.allowsMisalignedMemoryAccesses(
777             F.getContext(), SizeBytes * 8, AS, Alignment, &VectorizedSpeed);
778         if (!AllowsMisaligned) {
779           LLVM_DEBUG(dbgs()
780                      << "LSV: Access of " << SizeBytes << "B in addrspace "
781                      << AS << " with alignment " << Alignment.value()
782                      << " is misaligned, and therefore can't be vectorized.\n");
783           return false;
784         }
785 
786         unsigned ElementwiseSpeed = 0;
787         (TTI).allowsMisalignedMemoryAccesses((F).getContext(), VecElemBits, AS,
788                                              Alignment, &ElementwiseSpeed);
789         if (VectorizedSpeed < ElementwiseSpeed) {
790           LLVM_DEBUG(dbgs()
791                      << "LSV: Access of " << SizeBytes << "B in addrspace "
792                      << AS << " with alignment " << Alignment.value()
793                      << " has relative speed " << VectorizedSpeed
794                      << ", which is lower than the elementwise speed of "
795                      << ElementwiseSpeed
796                      << ".  Therefore this access won't be vectorized.\n");
797           return false;
798         }
799         return true;
800       };
801 
802       // If we're loading/storing from an alloca, align it if possible.
803       //
804       // FIXME: We eagerly upgrade the alignment, regardless of whether TTI
805       // tells us this is beneficial.  This feels a bit odd, but it matches
806       // existing tests.  This isn't *so* bad, because at most we align to 4
807       // bytes (current value of StackAdjustedAlignment).
808       //
809       // FIXME: We will upgrade the alignment of the alloca even if it turns out
810       // we can't vectorize for some other reason.
811       Value *PtrOperand = getLoadStorePointerOperand(C[CBegin].Inst);
812       bool IsAllocaAccess = AS == DL.getAllocaAddrSpace() &&
813                             isa<AllocaInst>(PtrOperand->stripPointerCasts());
814       Align Alignment = getLoadStoreAlignment(C[CBegin].Inst);
815       Align PrefAlign = Align(StackAdjustedAlignment);
816       if (IsAllocaAccess && Alignment.value() % SizeBytes != 0 &&
817           IsAllowedAndFast(PrefAlign)) {
818         Align NewAlign = getOrEnforceKnownAlignment(
819             PtrOperand, PrefAlign, DL, C[CBegin].Inst, nullptr, &DT);
820         if (NewAlign >= Alignment) {
821           LLVM_DEBUG(dbgs()
822                      << "LSV: splitChainByAlignment upgrading alloca alignment from "
823                      << Alignment.value() << " to " << NewAlign.value()
824                      << "\n");
825           Alignment = NewAlign;
826         }
827       }
828 
829       if (!IsAllowedAndFast(Alignment)) {
830         LLVM_DEBUG(
831             dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
832                       "because its alignment is not AllowedAndFast: "
833                    << Alignment.value() << "\n");
834         continue;
835       }
836 
837       if ((IsLoadChain &&
838            !TTI.isLegalToVectorizeLoadChain(SizeBytes, Alignment, AS)) ||
839           (!IsLoadChain &&
840            !TTI.isLegalToVectorizeStoreChain(SizeBytes, Alignment, AS))) {
841         LLVM_DEBUG(
842             dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
843                       "because !isLegalToVectorizeLoad/StoreChain.");
844         continue;
845       }
846 
847       // Hooray, we can vectorize this chain!
848       Chain &NewChain = Ret.emplace_back();
849       for (unsigned I = CBegin; I <= CEnd; ++I)
850         NewChain.emplace_back(C[I]);
851       CBegin = CEnd; // Skip over the instructions we've added to the chain.
852       break;
853     }
854   }
855   return Ret;
856 }
857 
858 bool Vectorizer::vectorizeChain(Chain &C) {
859   if (C.size() < 2)
860     return false;
861 
862   sortChainInOffsetOrder(C);
863 
864   LLVM_DEBUG({
865     dbgs() << "LSV: Vectorizing chain of " << C.size() << " instructions:\n";
866     dumpChain(C);
867   });
868 
869   Type *VecElemTy = getChainElemTy(C);
870   bool IsLoadChain = isa<LoadInst>(C[0].Inst);
871   unsigned AS = getLoadStoreAddressSpace(C[0].Inst);
872   unsigned ChainBytes = std::accumulate(
873       C.begin(), C.end(), 0u, [&](unsigned Bytes, const ChainElem &E) {
874         return Bytes + DL.getTypeStoreSize(getLoadStoreType(E.Inst));
875       });
876   assert(ChainBytes % DL.getTypeStoreSize(VecElemTy) == 0);
877   // VecTy's size is a power of 2 and at least 1 byte, but VecElemTy may be
878   // smaller than 1 byte (e.g. VecTy == <32 x i1>).
879   Type *VecTy = FixedVectorType::get(
880       VecElemTy, 8 * ChainBytes / DL.getTypeSizeInBits(VecElemTy));
881 
882   Align Alignment = getLoadStoreAlignment(C[0].Inst);
883   // If this is a load/store of an alloca, we might have upgraded the alloca's
884   // alignment earlier.  Get the new alignment.
885   if (AS == DL.getAllocaAddrSpace()) {
886     Alignment = std::max(
887         Alignment,
888         getOrEnforceKnownAlignment(getLoadStorePointerOperand(C[0].Inst),
889                                    MaybeAlign(), DL, C[0].Inst, nullptr, &DT));
890   }
891 
892   // All elements of the chain must have the same scalar-type size.
893 #ifndef NDEBUG
894   for (const ChainElem &E : C)
895     assert(DL.getTypeStoreSize(getLoadStoreType(E.Inst)->getScalarType()) ==
896            DL.getTypeStoreSize(VecElemTy));
897 #endif
898 
899   Instruction *VecInst;
900   if (IsLoadChain) {
901     // Loads get hoisted to the location of the first load in the chain.  We may
902     // also need to hoist the (transitive) operands of the loads.
903     Builder.SetInsertPoint(
904         llvm::min_element(C, [](const auto &A, const auto &B) {
905           return A.Inst->comesBefore(B.Inst);
906         })->Inst);
907 
908     // Chain is in offset order, so C[0] is the instr with the lowest offset,
909     // i.e. the root of the vector.
910     VecInst = Builder.CreateAlignedLoad(VecTy,
911                                         getLoadStorePointerOperand(C[0].Inst),
912                                         Alignment);
913 
914     unsigned VecIdx = 0;
915     for (const ChainElem &E : C) {
916       Instruction *I = E.Inst;
917       Value *V;
918       Type *T = getLoadStoreType(I);
919       if (auto *VT = dyn_cast<FixedVectorType>(T)) {
920         auto Mask = llvm::to_vector<8>(
921             llvm::seq<int>(VecIdx, VecIdx + VT->getNumElements()));
922         V = Builder.CreateShuffleVector(VecInst, Mask, I->getName());
923         VecIdx += VT->getNumElements();
924       } else {
925         V = Builder.CreateExtractElement(VecInst, Builder.getInt32(VecIdx),
926                                          I->getName());
927         ++VecIdx;
928       }
929       if (V->getType() != I->getType())
930         V = Builder.CreateBitOrPointerCast(V, I->getType());
931       I->replaceAllUsesWith(V);
932     }
933 
934     // Finally, we need to reorder the instrs in the BB so that the (transitive)
935     // operands of VecInst appear before it.  To see why, suppose we have
936     // vectorized the following code:
937     //
938     //   ptr1  = gep a, 1
939     //   load1 = load i32 ptr1
940     //   ptr0  = gep a, 0
941     //   load0 = load i32 ptr0
942     //
943     // We will put the vectorized load at the location of the earliest load in
944     // the BB, i.e. load1.  We get:
945     //
946     //   ptr1  = gep a, 1
947     //   loadv = load <2 x i32> ptr0
948     //   load0 = extractelement loadv, 0
949     //   load1 = extractelement loadv, 1
950     //   ptr0 = gep a, 0
951     //
952     // Notice that loadv uses ptr0, which is defined *after* it!
953     reorder(VecInst);
954   } else {
955     // Stores get sunk to the location of the last store in the chain.
956     Builder.SetInsertPoint(llvm::max_element(C, [](auto &A, auto &B) {
957                              return A.Inst->comesBefore(B.Inst);
958                            })->Inst);
959 
960     // Build the vector to store.
961     Value *Vec = PoisonValue::get(VecTy);
962     unsigned VecIdx = 0;
963     auto InsertElem = [&](Value *V) {
964       if (V->getType() != VecElemTy)
965         V = Builder.CreateBitOrPointerCast(V, VecElemTy);
966       Vec = Builder.CreateInsertElement(Vec, V, Builder.getInt32(VecIdx++));
967     };
968     for (const ChainElem &E : C) {
969       auto *I = cast<StoreInst>(E.Inst);
970       if (FixedVectorType *VT =
971               dyn_cast<FixedVectorType>(getLoadStoreType(I))) {
972         for (int J = 0, JE = VT->getNumElements(); J < JE; ++J) {
973           InsertElem(Builder.CreateExtractElement(I->getValueOperand(),
974                                                   Builder.getInt32(J)));
975         }
976       } else {
977         InsertElem(I->getValueOperand());
978       }
979     }
980 
981     // Chain is in offset order, so C[0] is the instr with the lowest offset,
982     // i.e. the root of the vector.
983     VecInst = Builder.CreateAlignedStore(
984         Vec,
985         getLoadStorePointerOperand(C[0].Inst),
986         Alignment);
987   }
988 
989   propagateMetadata(VecInst, C);
990 
991   for (const ChainElem &E : C)
992     ToErase.emplace_back(E.Inst);
993 
994   ++NumVectorInstructions;
995   NumScalarsVectorized += C.size();
996   return true;
997 }
998 
999 template <bool IsLoadChain>
1000 bool Vectorizer::isSafeToMove(
1001     Instruction *ChainElem, Instruction *ChainBegin,
1002     const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets) {
1003   LLVM_DEBUG(dbgs() << "LSV: isSafeToMove(" << *ChainElem << " -> "
1004                     << *ChainBegin << ")\n");
1005 
1006   assert(isa<LoadInst>(ChainElem) == IsLoadChain);
1007   if (ChainElem == ChainBegin)
1008     return true;
1009 
1010   // Invariant loads can always be reordered; by definition they are not
1011   // clobbered by stores.
1012   if (isInvariantLoad(ChainElem))
1013     return true;
1014 
1015   auto BBIt = std::next([&] {
1016     if constexpr (IsLoadChain)
1017       return BasicBlock::reverse_iterator(ChainElem);
1018     else
1019       return BasicBlock::iterator(ChainElem);
1020   }());
1021   auto BBItEnd = std::next([&] {
1022     if constexpr (IsLoadChain)
1023       return BasicBlock::reverse_iterator(ChainBegin);
1024     else
1025       return BasicBlock::iterator(ChainBegin);
1026   }());
1027 
1028   const APInt &ChainElemOffset = ChainOffsets.at(ChainElem);
1029   const unsigned ChainElemSize =
1030       DL.getTypeStoreSize(getLoadStoreType(ChainElem));
1031 
1032   for (; BBIt != BBItEnd; ++BBIt) {
1033     Instruction *I = &*BBIt;
1034 
1035     if (!I->mayReadOrWriteMemory())
1036       continue;
1037 
1038     // Loads can be reordered with other loads.
1039     if (IsLoadChain && isa<LoadInst>(I))
1040       continue;
1041 
1042     // Stores can be sunk below invariant loads.
1043     if (!IsLoadChain && isInvariantLoad(I))
1044       continue;
1045 
1046     // If I is in the chain, we can tell whether it aliases ChainElem by checking
1047     // what offset I accesses.  This may be better than AA is able to do.
1048     //
1049     // We should really only have duplicate offsets for stores (the duplicate
1050     // loads should be CSE'ed), but in case we have a duplicate load, we'll
1051     // split the chain so we don't have to handle this case specially.
1052     if (auto OffsetIt = ChainOffsets.find(I); OffsetIt != ChainOffsets.end()) {
1053       // I and ChainElem overlap if:
1054       //   - I and ChainElem have the same offset, OR
1055       //   - I's offset is less than ChainElem's, but I touches past the
1056       //     beginning of ChainElem, OR
1057       //   - ChainElem's offset is less than I's, but ChainElem touches past the
1058       //     beginning of I.
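      // For example (hypothetical offsets), an 8-byte access at offset 0
      // overlaps a 4-byte access at offset 4 but not one at offset 8.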
1059       const APInt &IOffset = OffsetIt->second;
1060       unsigned IElemSize = DL.getTypeStoreSize(getLoadStoreType(I));
1061       if (IOffset == ChainElemOffset ||
1062           (IOffset.sle(ChainElemOffset) &&
1063            (IOffset + IElemSize).sgt(ChainElemOffset)) ||
1064           (ChainElemOffset.sle(IOffset) &&
1065            (ChainElemOffset + ChainElemSize).sgt(OffsetIt->second))) {
1066         LLVM_DEBUG({
1067           // Double check that AA also sees this alias.  If not, we probably
1068           // have a bug.
1069           ModRefInfo MR = AA.getModRefInfo(I, MemoryLocation::get(ChainElem));
1070           assert(IsLoadChain ? isModSet(MR) : isModOrRefSet(MR));
1071           dbgs() << "LSV: Found alias in chain: " << *I << "\n";
1072         });
1073         return false; // We found an aliasing instruction; bail.
1074       }
1075 
1076       continue; // We're confident there's no alias.
1077     }
1078 
1079     LLVM_DEBUG(dbgs() << "LSV: Querying AA for " << *I << "\n");
1080     ModRefInfo MR = AA.getModRefInfo(I, MemoryLocation::get(ChainElem));
1081     if (IsLoadChain ? isModSet(MR) : isModOrRefSet(MR)) {
1082       LLVM_DEBUG(dbgs() << "LSV: Found alias in chain:\n"
1083                         << "  Aliasing instruction:\n"
1084                         << "    " << *I << '\n'
1085                         << "  Aliased instruction and pointer:\n"
1086                         << "    " << *ChainElem << '\n'
1087                         << "    " << *getLoadStorePointerOperand(ChainElem)
1088                         << '\n');
1089 
1090       return false;
1091     }
1092   }
1093   return true;
1094 }
1095 
1096 static bool checkNoWrapFlags(Instruction *I, bool Signed) {
1097   BinaryOperator *BinOpI = cast<BinaryOperator>(I);
1098   return (Signed && BinOpI->hasNoSignedWrap()) ||
1099          (!Signed && BinOpI->hasNoUnsignedWrap());
1100 }
1101 
1102 static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
1103                                    unsigned MatchingOpIdxA, Instruction *AddOpB,
1104                                    unsigned MatchingOpIdxB, bool Signed) {
1105   LLVM_DEBUG(dbgs() << "LSV: checkIfSafeAddSequence IdxDiff=" << IdxDiff
1106                     << ", AddOpA=" << *AddOpA << ", MatchingOpIdxA="
1107                     << MatchingOpIdxA << ", AddOpB=" << *AddOpB
1108                     << ", MatchingOpIdxB=" << MatchingOpIdxB
1109                     << ", Signed=" << Signed << "\n");
1110   // If both OpA and OpB are adds with NSW/NUW and with one of the operands
1111   // being the same, we can guarantee that the transformation is safe if we can
1112   // prove that OpA won't overflow when IdxDiff is added to the other operand of OpA.
1113   // For example:
1114   //  %tmp7 = add nsw i32 %tmp2, %v0
1115   //  %tmp8 = sext i32 %tmp7 to i64
1116   //  ...
1117   //  %tmp11 = add nsw i32 %v0, 1
1118   //  %tmp12 = add nsw i32 %tmp2, %tmp11
1119   //  %tmp13 = sext i32 %tmp12 to i64
1120   //
1121   //  Both %tmp7 and %tmp12 have the nsw flag and the first operand is %tmp2.
1122   //  It's guaranteed that adding 1 to %tmp7 won't overflow because %tmp11 adds
1123   //  1 to %v0 and both %tmp11 and %tmp12 have the nsw flag.
1124   assert(AddOpA->getOpcode() == Instruction::Add &&
1125          AddOpB->getOpcode() == Instruction::Add &&
1126          checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed));
1127   if (AddOpA->getOperand(MatchingOpIdxA) ==
1128       AddOpB->getOperand(MatchingOpIdxB)) {
1129     Value *OtherOperandA = AddOpA->getOperand(MatchingOpIdxA == 1 ? 0 : 1);
1130     Value *OtherOperandB = AddOpB->getOperand(MatchingOpIdxB == 1 ? 0 : 1);
1131     Instruction *OtherInstrA = dyn_cast<Instruction>(OtherOperandA);
1132     Instruction *OtherInstrB = dyn_cast<Instruction>(OtherOperandB);
1133     // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
1134     if (OtherInstrB && OtherInstrB->getOpcode() == Instruction::Add &&
1135         checkNoWrapFlags(OtherInstrB, Signed) &&
1136         isa<ConstantInt>(OtherInstrB->getOperand(1))) {
1137       int64_t CstVal =
1138           cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
1139       if (OtherInstrB->getOperand(0) == OtherOperandA &&
1140           IdxDiff.getSExtValue() == CstVal)
1141         return true;
1142     }
1143     // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
1144     if (OtherInstrA && OtherInstrA->getOpcode() == Instruction::Add &&
1145         checkNoWrapFlags(OtherInstrA, Signed) &&
1146         isa<ConstantInt>(OtherInstrA->getOperand(1))) {
1147       int64_t CstVal =
1148           cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
1149       if (OtherInstrA->getOperand(0) == OtherOperandB &&
1150           IdxDiff.getSExtValue() == -CstVal)
1151         return true;
1152     }
1153     // Match `x +nsw/nuw (y +nsw/nuw c)` and
1154     // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
1155     if (OtherInstrA && OtherInstrB &&
1156         OtherInstrA->getOpcode() == Instruction::Add &&
1157         OtherInstrB->getOpcode() == Instruction::Add &&
1158         checkNoWrapFlags(OtherInstrA, Signed) &&
1159         checkNoWrapFlags(OtherInstrB, Signed) &&
1160         isa<ConstantInt>(OtherInstrA->getOperand(1)) &&
1161         isa<ConstantInt>(OtherInstrB->getOperand(1))) {
1162       int64_t CstValA =
1163           cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
1164       int64_t CstValB =
1165           cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
1166       if (OtherInstrA->getOperand(0) == OtherInstrB->getOperand(0) &&
1167           IdxDiff.getSExtValue() == (CstValB - CstValA))
1168         return true;
1169     }
1170   }
1171   return false;
1172 }
1173 
1174 std::optional<APInt> Vectorizer::getConstantOffsetComplexAddrs(
1175     Value *PtrA, Value *PtrB, Instruction *ContextInst, unsigned Depth) {
1176   LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs PtrA=" << *PtrA
1177                     << " PtrB=" << *PtrB << " ContextInst=" << *ContextInst
1178                     << " Depth=" << Depth << "\n");
1179   auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
1180   auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
1181   if (!GEPA || !GEPB)
1182     return getConstantOffsetSelects(PtrA, PtrB, ContextInst, Depth);
1183 
1184   // Look through GEPs after checking they're the same except for the last
1185   // index.
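  // For example (hypothetical IR), this handles pairs like
  //   %a = getelementptr inbounds i32, ptr %base, i64 %i.ext
  //   %b = getelementptr inbounds i32, ptr %base, i64 %j.ext
  // where only the last index differs (%i.ext and %j.ext being matching
  // sext/zext of narrower indices); below we try to prove that the index
  // difference is a constant that can be added without overflow.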
1186   if (GEPA->getNumOperands() != GEPB->getNumOperands() ||
1187       GEPA->getPointerOperand() != GEPB->getPointerOperand())
1188     return std::nullopt;
1189   gep_type_iterator GTIA = gep_type_begin(GEPA);
1190   gep_type_iterator GTIB = gep_type_begin(GEPB);
1191   for (unsigned I = 0, E = GEPA->getNumIndices() - 1; I < E; ++I) {
1192     if (GTIA.getOperand() != GTIB.getOperand())
1193       return std::nullopt;
1194     ++GTIA;
1195     ++GTIB;
1196   }
1197 
1198   Instruction *OpA = dyn_cast<Instruction>(GTIA.getOperand());
1199   Instruction *OpB = dyn_cast<Instruction>(GTIB.getOperand());
1200   if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
1201       OpA->getType() != OpB->getType())
1202     return std::nullopt;
1203 
1204   uint64_t Stride = GTIA.getSequentialElementStride(DL);
1205 
1206   // Only look through a ZExt/SExt.
1207   if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
1208     return std::nullopt;
1209 
1210   bool Signed = isa<SExtInst>(OpA);
1211 
1212   // At this point A could be a function parameter, i.e. not an instruction
1213   Value *ValA = OpA->getOperand(0);
1214   OpB = dyn_cast<Instruction>(OpB->getOperand(0));
1215   if (!OpB || ValA->getType() != OpB->getType())
1216     return std::nullopt;
1217 
1218   const SCEV *OffsetSCEVA = SE.getSCEV(ValA);
1219   const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
1220   const SCEV *IdxDiffSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA);
1221   if (IdxDiffSCEV == SE.getCouldNotCompute())
1222     return std::nullopt;
1223 
1224   ConstantRange IdxDiffRange = SE.getSignedRange(IdxDiffSCEV);
1225   if (!IdxDiffRange.isSingleElement())
1226     return std::nullopt;
1227   APInt IdxDiff = *IdxDiffRange.getSingleElement();
1228 
1229   LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs IdxDiff=" << IdxDiff
1230                     << "\n");
1231 
1232   // Now we need to prove that adding IdxDiff to ValA won't overflow.
1233   bool Safe = false;
1234 
1235   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
1236   // ValA, we're okay.
1237   if (OpB->getOpcode() == Instruction::Add &&
1238       isa<ConstantInt>(OpB->getOperand(1)) &&
1239       IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
1240       checkNoWrapFlags(OpB, Signed))
1241     Safe = true;
1242 
1243   // Second attempt: check if we have eligible add NSW/NUW instruction
1244   // sequences.
1245   OpA = dyn_cast<Instruction>(ValA);
1246   if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
1247       OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) &&
1248       checkNoWrapFlags(OpB, Signed)) {
1249     // In the checks below a matching operand in OpA and OpB is an operand which
1250     // is the same in those two instructions.  Below we account for possible
1251     // orders of the operands of these add instructions.
1252     for (unsigned MatchingOpIdxA : {0, 1})
1253       for (unsigned MatchingOpIdxB : {0, 1})
1254         if (!Safe)
1255           Safe = checkIfSafeAddSequence(IdxDiff, OpA, MatchingOpIdxA, OpB,
1256                                         MatchingOpIdxB, Signed);
1257   }
1258 
1259   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
1260 
1261   // Third attempt:
1262   //
1263   // Assuming IdxDiff is positive: If all set bits of IdxDiff or any higher
1264   // order bit other than the sign bit are known to be zero in ValA, we can add
1265   // IdxDiff to it while guaranteeing no overflow of any sort.
1266   //
1267   // If IdxDiff is negative, do the same, but swap ValA and ValB.
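  //
  // For example (hypothetical), if ValA is known to be a multiple of 16 (its
  // low four bits are known zero), any IdxDiff with |IdxDiff| <= 15 can be
  // added without carrying into the unknown bits, so no overflow occurs.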
1268   if (!Safe) {
1269     // When computing known bits, use the GEPs as context instructions, since
1270     // they likely are in the same BB as the load/store.
1271     KnownBits Known(BitWidth);
1272     computeKnownBits((IdxDiff.sge(0) ? ValA : OpB), Known, DL, 0, &AC,
1273                      ContextInst, &DT);
1274     APInt BitsAllowedToBeSet = Known.Zero.zext(IdxDiff.getBitWidth());
1275     if (Signed)
1276       BitsAllowedToBeSet.clearBit(BitWidth - 1);
1277     if (BitsAllowedToBeSet.ult(IdxDiff.abs()))
1278       return std::nullopt;
1279     Safe = true;
1280   }
1281 
1282   if (Safe)
1283     return IdxDiff * Stride;
1284   return std::nullopt;
1285 }
1286 
1287 std::optional<APInt> Vectorizer::getConstantOffsetSelects(
1288     Value *PtrA, Value *PtrB, Instruction *ContextInst, unsigned Depth) {
1289   if (Depth++ == MaxDepth)
1290     return std::nullopt;
1291 
1292   if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
1293     if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
1294       if (SelectA->getCondition() != SelectB->getCondition())
1295         return std::nullopt;
1296       LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetSelects, PtrA=" << *PtrA
1297                         << ", PtrB=" << *PtrB << ", ContextInst="
1298                         << *ContextInst << ", Depth=" << Depth << "\n");
1299       std::optional<APInt> TrueDiff = getConstantOffset(
1300           SelectA->getTrueValue(), SelectB->getTrueValue(), ContextInst, Depth);
1301       if (!TrueDiff)
1302         return std::nullopt;
1303       std::optional<APInt> FalseDiff =
1304           getConstantOffset(SelectA->getFalseValue(), SelectB->getFalseValue(),
1305                             ContextInst, Depth);
1306       if (TrueDiff == FalseDiff)
1307         return TrueDiff;
1308     }
1309   }
1310   return std::nullopt;
1311 }
1312 
1313 void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const {
1314   if (EQClasses.size() < 2) // There is nothing to merge.
1315     return;
1316 
1317   // The reduced key has all elements of the EqClassKey except the underlying
1318   // object. Check that EqClassKey has 4 elements and define the reduced key.
1319   static_assert(std::tuple_size_v<EqClassKey> == 4,
1320                 "EqClassKey has changed - EqClassReducedKey needs changes too");
1321   using EqClassReducedKey =
1322       std::tuple<std::tuple_element_t<1, EqClassKey> /* AddrSpace */,
1323                  std::tuple_element_t<2, EqClassKey> /* Element size */,
1324                  std::tuple_element_t<3, EqClassKey> /* IsLoad */>;
1325   using ECReducedKeyToUnderlyingObjectMap =
1326       MapVector<EqClassReducedKey,
1327                 SmallPtrSet<std::tuple_element_t<0, EqClassKey>, 4>>;
1328 
1329   // Form a map from the reduced key (without the underlying object) to the
1330   // underlying objects: one reduced key maps to many underlying objects, which
1331   // form groups of potentially mergeable equivalence classes.
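  // Illustrative sketch (hypothetical values): classes keyed by {%obj, AS 0,
  // 32 bits, load} and {%alias, AS 0, 32 bits, load}, where %alias is a
  // one-level indirection to %obj, share the reduced key {AS 0, 32 bits, load}
  // and therefore become merge candidates below.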
1332   ECReducedKeyToUnderlyingObjectMap RedKeyToUOMap;
1333   bool FoundPotentiallyOptimizableEC = false;
1334   for (const auto &EC : EQClasses) {
1335     const auto &Key = EC.first;
1336     EqClassReducedKey RedKey{std::get<1>(Key), std::get<2>(Key),
1337                              std::get<3>(Key)};
1338     RedKeyToUOMap[RedKey].insert(std::get<0>(Key));
1339     if (RedKeyToUOMap[RedKey].size() > 1)
1340       FoundPotentiallyOptimizableEC = true;
1341   }
1342   if (!FoundPotentiallyOptimizableEC)
1343     return;
1344 
1345   LLVM_DEBUG({
1346     dbgs() << "LSV: mergeEquivalenceClasses: before merging:\n";
1347     for (const auto &EC : EQClasses) {
1348       dbgs() << "  Key: {" << EC.first << "}\n";
1349       for (const auto &Inst : EC.second)
1350         dbgs() << "    Inst: " << *Inst << '\n';
1351     }
1352   });
1353   LLVM_DEBUG({
1354     dbgs() << "LSV: mergeEquivalenceClasses: RedKeyToUOMap:\n";
1355     for (const auto &RedKeyToUO : RedKeyToUOMap) {
1356       dbgs() << "  Reduced key: {" << std::get<0>(RedKeyToUO.first) << ", "
1357              << std::get<1>(RedKeyToUO.first) << ", "
1358              << static_cast<int>(std::get<2>(RedKeyToUO.first)) << "} --> "
1359              << RedKeyToUO.second.size() << " underlying objects:\n";
1360       for (auto UObject : RedKeyToUO.second)
1361         dbgs() << "    " << *UObject << '\n';
1362     }
1363   });
1364 
1365   using UObjectToUObjectMap = DenseMap<const Value *, const Value *>;
1366 
1367   // Compute the ultimate targets for a set of underlying objects.
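  // E.g. (illustrative): if %a resolves to %b and %b resolves to %c via
  // one-level getUnderlyingObject lookups, the ultimate target of %a, %b, and
  // %c is %c, and all three map to %c in the returned map.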
1368   auto GetUltimateTargets =
1369       [](SmallPtrSetImpl<const Value *> &UObjects) -> UObjectToUObjectMap {
1370     UObjectToUObjectMap IndirectionMap;
1371     for (const auto *UObject : UObjects) {
1372       const unsigned MaxLookupDepth = 1; // look for 1-level indirections only
1373       const auto *UltimateTarget = getUnderlyingObject(UObject, MaxLookupDepth);
1374       if (UltimateTarget != UObject)
1375         IndirectionMap[UObject] = UltimateTarget;
1376     }
1377     UObjectToUObjectMap UltimateTargetsMap;
1378     for (const auto *UObject : UObjects) {
1379       auto Target = UObject;
1380       auto It = IndirectionMap.find(Target);
1381       for (; It != IndirectionMap.end(); It = IndirectionMap.find(Target))
1382         Target = It->second;
1383       UltimateTargetsMap[UObject] = Target;
1384     }
1385     return UltimateTargetsMap;
1386   };
1387 
1388   // For each item in RedKeyToUOMap, if it has more than one underlying object,
1389   // try to merge the equivalence classes.
1390   for (auto &[RedKey, UObjects] : RedKeyToUOMap) {
1391     if (UObjects.size() < 2)
1392       continue;
1393     auto UTMap = GetUltimateTargets(UObjects);
1394     for (const auto &[UObject, UltimateTarget] : UTMap) {
1395       if (UObject == UltimateTarget)
1396         continue;
1397 
1398       EqClassKey KeyFrom{UObject, std::get<0>(RedKey), std::get<1>(RedKey),
1399                          std::get<2>(RedKey)};
1400       EqClassKey KeyTo{UltimateTarget, std::get<0>(RedKey), std::get<1>(RedKey),
1401                        std::get<2>(RedKey)};
1402       // The entry for KeyFrom is guaranteed to exist, unlike KeyTo. Thus,
1403       // request the reference to the instructions vector for KeyTo first.
1404       const auto &VecTo = EQClasses[KeyTo];
1405       const auto &VecFrom = EQClasses[KeyFrom];
1406       SmallVector<Instruction *, 8> MergedVec;
1407       std::merge(VecFrom.begin(), VecFrom.end(), VecTo.begin(), VecTo.end(),
1408                  std::back_inserter(MergedVec),
1409                  [](Instruction *A, Instruction *B) {
1410                    return A && B && A->comesBefore(B);
1411                  });
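      // Both input vectors hold instructions in program order (they were
      // collected by scanning the BB front to back), so merging with the
      // comesBefore comparator keeps the result in program order too.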
1412       EQClasses[KeyTo] = std::move(MergedVec);
1413       EQClasses.erase(KeyFrom);
1414     }
1415   }
1416   LLVM_DEBUG({
1417     dbgs() << "LSV: mergeEquivalenceClasses: after merging:\n";
1418     for (const auto &EC : EQClasses) {
1419       dbgs() << "  Key: {" << EC.first << "}\n";
1420       for (const auto &Inst : EC.second)
1421         dbgs() << "    Inst: " << *Inst << '\n';
1422     }
1423   });
1424 }
1425 
1426 EquivalenceClassMap
1427 Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
1428                                       BasicBlock::iterator End) {
1429   EquivalenceClassMap Ret;
1430 
1431   auto GetUnderlyingObject = [](const Value *Ptr) -> const Value * {
1432     const Value *ObjPtr = llvm::getUnderlyingObject(Ptr);
1433     if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
1434       // The selects themselves are distinct instructions even if they share
1435       // the same condition and evaluate to consecutive pointers for the true
1436       // and false values of that condition.  Grouping instructions by the
1437       // selects themselves would put such consecutive accesses into different
1438       // lists; they would never even be checked for being consecutive and so
1439       // would not be vectorized.  Group by the shared condition instead.
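      // Illustrative sketch (hypothetical IR):
      //   %p0 = select i1 %c, ptr %a, ptr %b
      //   %p1 = select i1 %c, ptr %a4, ptr %b4   ; arms 4 bytes past %a/%b
      // Keying both accesses on %c keeps them in one equivalence class, so the
      // later offset computation can discover that they are consecutive.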
1440       return Sel->getCondition();
1441     }
1442     return ObjPtr;
1443   };
1444 
1445   for (Instruction &I : make_range(Begin, End)) {
1446     auto *LI = dyn_cast<LoadInst>(&I);
1447     auto *SI = dyn_cast<StoreInst>(&I);
1448     if (!LI && !SI)
1449       continue;
1450 
1451     if ((LI && !LI->isSimple()) || (SI && !SI->isSimple()))
1452       continue;
1453 
1454     if ((LI && !TTI.isLegalToVectorizeLoad(LI)) ||
1455         (SI && !TTI.isLegalToVectorizeStore(SI)))
1456       continue;
1457 
1458     Type *Ty = getLoadStoreType(&I);
1459     if (!VectorType::isValidElementType(Ty->getScalarType()))
1460       continue;
1461 
1462     // Skip weird non-byte sizes. They probably aren't worth the effort of
1463     // handling correctly.
1464     unsigned TySize = DL.getTypeSizeInBits(Ty);
1465     if ((TySize % 8) != 0)
1466       continue;
1467 
1468     // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
1469     // functions currently use an integer type for the vectorized load/store
1470     // and do not support casting between that integer type and a vector of
1471     // pointers (e.g. i64 to <2 x i16*>).
1472     if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
1473       continue;
1474 
1475     Value *Ptr = getLoadStorePointerOperand(&I);
1476     unsigned AS = Ptr->getType()->getPointerAddressSpace();
1477     unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);
1478 
1479     unsigned VF = VecRegSize / TySize;
1480     VectorType *VecTy = dyn_cast<VectorType>(Ty);
1481 
1482     // Only handle power-of-two sized elements.
1483     if ((!VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(Ty))) ||
1484         (VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(VecTy->getScalarType()))))
1485       continue;
1486 
1487     // No point in looking at these if they're too big to vectorize.
1488     if (TySize > VecRegSize / 2 ||
1489         (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
1490       continue;
1491 
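    // Bucket the access by {underlying object, address space, scalar element
    // width in bits, load-vs-store}.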
1492     Ret[{GetUnderlyingObject(Ptr), AS,
1493          DL.getTypeSizeInBits(getLoadStoreType(&I)->getScalarType()),
1494          /*IsLoad=*/LI != nullptr}]
1495         .emplace_back(&I);
1496   }
1497 
1498   mergeEquivalenceClasses(Ret);
1499   return Ret;
1500 }
1501 
1502 std::vector<Chain> Vectorizer::gatherChains(ArrayRef<Instruction *> Instrs) {
1503   if (Instrs.empty())
1504     return {};
1505 
1506   unsigned AS = getLoadStoreAddressSpace(Instrs[0]);
1507   unsigned ASPtrBits = DL.getIndexSizeInBits(AS);
1508 
1509 #ifndef NDEBUG
1510   // Check that Instrs is in BB order and all have the same addr space.
1511   for (size_t I = 1; I < Instrs.size(); ++I) {
1512     assert(Instrs[I - 1]->comesBefore(Instrs[I]));
1513     assert(getLoadStoreAddressSpace(Instrs[I]) == AS);
1514   }
1515 #endif
1516 
1517   // Machinery to build an MRU-hashtable of Chains.
1518   //
1519   // (Ideally this could be done with MapVector, but as currently implemented,
1520   // moving an element to the front of a MapVector is O(n).)
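  // Note: the intrusive list gives O(1) move-to-front and a deterministic
  // iteration order; within this function the Chains set merely tracks how
  // many chains were created so the result vector can be reserved up front.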
1521   struct InstrListElem : ilist_node<InstrListElem>,
1522                          std::pair<Instruction *, Chain> {
1523     explicit InstrListElem(Instruction *I)
1524         : std::pair<Instruction *, Chain>(I, {}) {}
1525   };
1526   struct InstrListElemDenseMapInfo {
1527     using PtrInfo = DenseMapInfo<InstrListElem *>;
1528     using IInfo = DenseMapInfo<Instruction *>;
1529     static InstrListElem *getEmptyKey() { return PtrInfo::getEmptyKey(); }
1530     static InstrListElem *getTombstoneKey() {
1531       return PtrInfo::getTombstoneKey();
1532     }
1533     static unsigned getHashValue(const InstrListElem *E) {
1534       return IInfo::getHashValue(E->first);
1535     }
1536     static bool isEqual(const InstrListElem *A, const InstrListElem *B) {
1537       if (A == getEmptyKey() || B == getEmptyKey())
1538         return A == getEmptyKey() && B == getEmptyKey();
1539       if (A == getTombstoneKey() || B == getTombstoneKey())
1540         return A == getTombstoneKey() && B == getTombstoneKey();
1541       return IInfo::isEqual(A->first, B->first);
1542     }
1543   };
1544   SpecificBumpPtrAllocator<InstrListElem> Allocator;
1545   simple_ilist<InstrListElem> MRU;
1546   DenseSet<InstrListElem *, InstrListElemDenseMapInfo> Chains;
1547 
1548   // Compare each instruction in `Instrs` to the leaders of the N most
1549   // recently-used chains.  This limits the O(n^2) behavior of this pass while
1550   // also allowing us to build arbitrarily long chains.
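  // E.g. (illustrative): for accesses at %p, %p+4, and %p+8 bytes, the first
  // access starts a chain at offset 0; the next two match that chain's leader
  // with offsets 4 and 8, so all three accesses land in a single chain.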
1551   for (Instruction *I : Instrs) {
1552     constexpr int MaxChainsToTry = 64;
1553 
1554     bool MatchFound = false;
1555     auto ChainIter = MRU.begin();
1556     for (size_t J = 0; J < MaxChainsToTry && ChainIter != MRU.end();
1557          ++J, ++ChainIter) {
1558       if (std::optional<APInt> Offset = getConstantOffset(
1559               getLoadStorePointerOperand(ChainIter->first),
1560               getLoadStorePointerOperand(I),
1561               /*ContextInst=*/
1562               (ChainIter->first->comesBefore(I) ? I : ChainIter->first))) {
1563         // `Offset` might not have the expected number of bits, if e.g. AS has a
1564         // different number of bits than opaque pointers.
1565         ChainIter->second.emplace_back(I, Offset.value());
1566         // Move ChainIter to the front of the MRU list.
1567         MRU.remove(*ChainIter);
1568         MRU.push_front(*ChainIter);
1569         MatchFound = true;
1570         break;
1571       }
1572     }
1573 
1574     if (!MatchFound) {
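      // No recently-used chain has a leader at a known constant offset from I;
      // start a new chain with I as its leader at offset 0.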
1575       APInt ZeroOffset(ASPtrBits, 0);
1576       InstrListElem *E = new (Allocator.Allocate()) InstrListElem(I);
1577       E->second.emplace_back(I, ZeroOffset);
1578       MRU.push_front(*E);
1579       Chains.insert(E);
1580     }
1581   }
1582 
1583   std::vector<Chain> Ret;
1584   Ret.reserve(Chains.size());
1585   // Iterate over MRU rather than Chains so the order is deterministic.
1586   for (auto &E : MRU)
1587     if (E.second.size() > 1)
1588       Ret.emplace_back(std::move(E.second));
1589   return Ret;
1590 }
1591 
1592 std::optional<APInt> Vectorizer::getConstantOffset(Value *PtrA, Value *PtrB,
1593                                                    Instruction *ContextInst,
1594                                                    unsigned Depth) {
1595   LLVM_DEBUG(dbgs() << "LSV: getConstantOffset, PtrA=" << *PtrA
1596                     << ", PtrB=" << *PtrB << ", ContextInst=" << *ContextInst
1597                     << ", Depth=" << Depth << "\n");
1598   // We'll ultimately return a value of this bit width, even if computations
1599   // happen in a different width.
1600   unsigned OrigBitWidth = DL.getIndexTypeSizeInBits(PtrA->getType());
1601   APInt OffsetA(OrigBitWidth, 0);
1602   APInt OffsetB(OrigBitWidth, 0);
1603   PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
1604   PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
1605   unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType());
1606   if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType()))
1607     return std::nullopt;
1608 
1609   // If we have to shrink the pointer, stripAndAccumulateInBoundsConstantOffsets
1610   // should properly handle a possible overflow and the value should fit into
1611   // the smallest data type used in the cast/gep chain.
1612   assert(OffsetA.getSignificantBits() <= NewPtrBitWidth &&
1613          OffsetB.getSignificantBits() <= NewPtrBitWidth);
1614 
1615   OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
1616   OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
1617   if (PtrA == PtrB)
1618     return (OffsetB - OffsetA).sextOrTrunc(OrigBitWidth);
1619 
1620   // Try to compute B - A.
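  // Illustrative sketch (hypothetical IR): for PtrA = getelementptr i32, ptr
  // %base, i64 %i and PtrB = getelementptr i32, ptr %base, i64 %j with
  // %j == %i + 3, SCEV can often fold PtrB - PtrA to the constant 12.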
1621   const SCEV *DistScev = SE.getMinusSCEV(SE.getSCEV(PtrB), SE.getSCEV(PtrA));
1622   if (DistScev != SE.getCouldNotCompute()) {
1623     LLVM_DEBUG(dbgs() << "LSV: SCEV PtrB - PtrA =" << *DistScev << "\n");
1624     ConstantRange DistRange = SE.getSignedRange(DistScev);
1625     if (DistRange.isSingleElement()) {
1626       // Handle index width (the width of Dist) != pointer width (the width of
1627       // the Offset*s at this point).
1628       APInt Dist = DistRange.getSingleElement()->sextOrTrunc(NewPtrBitWidth);
1629       return (OffsetB - OffsetA + Dist).sextOrTrunc(OrigBitWidth);
1630     }
1631   }
1632   if (std::optional<APInt> Diff =
1633           getConstantOffsetComplexAddrs(PtrA, PtrB, ContextInst, Depth))
1634     return (OffsetB - OffsetA + Diff->sext(OffsetB.getBitWidth()))
1635         .sextOrTrunc(OrigBitWidth);
1636   return std::nullopt;
1637 }
1638