xref: /llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp (revision bf873aa3ecef93c8dc8eb792da9e73ceff120492)
1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes scalar/vector interactions using target cost models. The
10 // transforms implemented here may not fit in traditional loop-based or SLP
11 // vectorization passes.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Vectorize/VectorCombine.h"
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/Analysis/AssumptionCache.h"
21 #include "llvm/Analysis/BasicAliasAnalysis.h"
22 #include "llvm/Analysis/GlobalsModRef.h"
23 #include "llvm/Analysis/Loads.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Analysis/VectorUtils.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/PatternMatch.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Transforms/Utils/Local.h"
33 #include "llvm/Transforms/Utils/LoopUtils.h"
34 #include <numeric>
35 #include <queue>
36 
37 #define DEBUG_TYPE "vector-combine"
38 #include "llvm/Transforms/Utils/InstructionWorklist.h"
39 
40 using namespace llvm;
41 using namespace llvm::PatternMatch;
42 
43 STATISTIC(NumVecLoad, "Number of vector loads formed");
44 STATISTIC(NumVecCmp, "Number of vector compares formed");
45 STATISTIC(NumVecBO, "Number of vector binops formed");
46 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
47 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
48 STATISTIC(NumScalarBO, "Number of scalar binops formed");
49 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
50 
51 static cl::opt<bool> DisableVectorCombine(
52     "disable-vector-combine", cl::init(false), cl::Hidden,
53     cl::desc("Disable all vector combine transforms"));
54 
55 static cl::opt<bool> DisableBinopExtractShuffle(
56     "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
57     cl::desc("Disable binop extract to shuffle transforms"));
58 
59 static cl::opt<unsigned> MaxInstrsToScan(
60     "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
61     cl::desc("Max number of instructions to scan for vector combining."));
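
// For reference, these flags can be exercised through the standalone pass in
// 'opt' (illustrative invocation; flag values chosen arbitrarily):
//   opt -passes=vector-combine -vector-combine-max-scan-instrs=16 -S input.ll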
62 
63 static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
64 
65 namespace {
66 class VectorCombine {
67 public:
68   VectorCombine(Function &F, const TargetTransformInfo &TTI,
69                 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
70                 const DataLayout *DL, TTI::TargetCostKind CostKind,
71                 bool TryEarlyFoldsOnly)
72       : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), DL(DL),
73         CostKind(CostKind), TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
74 
75   bool run();
76 
77 private:
78   Function &F;
79   IRBuilder<> Builder;
80   const TargetTransformInfo &TTI;
81   const DominatorTree &DT;
82   AAResults &AA;
83   AssumptionCache &AC;
84   const DataLayout *DL;
85   TTI::TargetCostKind CostKind;
86 
87   /// If true, only perform beneficial early IR transforms. Do not introduce new
88   /// vector operations.
89   bool TryEarlyFoldsOnly;
90 
91   InstructionWorklist Worklist;
92 
93   // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
94   //       parameter. That should be updated to specific sub-classes because the
95   //       run loop was changed to dispatch on opcode.
96   bool vectorizeLoadInsert(Instruction &I);
97   bool widenSubvectorLoad(Instruction &I);
98   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
99                                         ExtractElementInst *Ext1,
100                                         unsigned PreferredExtractIndex) const;
101   bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
102                              const Instruction &I,
103                              ExtractElementInst *&ConvertToShuffle,
104                              unsigned PreferredExtractIndex);
105   void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
106                      Instruction &I);
107   void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
108                        Instruction &I);
109   bool foldExtractExtract(Instruction &I);
110   bool foldInsExtFNeg(Instruction &I);
111   bool foldInsExtVectorToShuffle(Instruction &I);
112   bool foldBitcastShuffle(Instruction &I);
113   bool scalarizeBinopOrCmp(Instruction &I);
114   bool scalarizeVPIntrinsic(Instruction &I);
115   bool foldExtractedCmps(Instruction &I);
116   bool foldSingleElementStore(Instruction &I);
117   bool scalarizeLoadExtract(Instruction &I);
118   bool foldConcatOfBoolMasks(Instruction &I);
119   bool foldPermuteOfBinops(Instruction &I);
120   bool foldShuffleOfBinops(Instruction &I);
121   bool foldShuffleOfCastops(Instruction &I);
122   bool foldShuffleOfShuffles(Instruction &I);
123   bool foldShuffleOfIntrinsics(Instruction &I);
124   bool foldShuffleToIdentity(Instruction &I);
125   bool foldShuffleFromReductions(Instruction &I);
126   bool foldCastFromReductions(Instruction &I);
127   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
128   bool shrinkType(Instruction &I);
129 
130   void replaceValue(Value &Old, Value &New) {
131     Old.replaceAllUsesWith(&New);
132     if (auto *NewI = dyn_cast<Instruction>(&New)) {
133       New.takeName(&Old);
134       Worklist.pushUsersToWorkList(*NewI);
135       Worklist.pushValue(NewI);
136     }
137     Worklist.pushValue(&Old);
138   }
139 
140   void eraseInstruction(Instruction &I) {
141     LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
142     for (Value *Op : I.operands())
143       Worklist.pushValue(Op);
144     Worklist.remove(&I);
145     I.eraseFromParent();
146   }
147 };
148 } // namespace
149 
150 /// Return the source operand of a potentially bitcasted value. If there is no
151 /// bitcast, return the input value itself.
152 static Value *peekThroughBitcasts(Value *V) {
153   while (auto *BitCast = dyn_cast<BitCastInst>(V))
154     V = BitCast->getOperand(0);
155   return V;
156 }
157 
158 static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
159   // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
160   // The widened load may load data from dirty regions or create data races
161   // non-existent in the source.
162   if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
163       Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
164       mustSuppressSpeculation(*Load))
165     return false;
166 
167   // We are potentially transforming byte-sized (8-bit) memory accesses, so make
168   // sure we have all of our type-based constraints in place for this target.
169   Type *ScalarTy = Load->getType()->getScalarType();
170   uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
171   unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
172   if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
173       ScalarSize % 8 != 0)
174     return false;
175 
176   return true;
177 }
178 
179 bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
180   // Match insert into fixed vector of scalar value.
181   // TODO: Handle non-zero insert index.
182   Value *Scalar;
183   if (!match(&I,
184              m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt())))
185     return false;
186 
187   // Optionally match an extract from another vector.
188   Value *X;
189   bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
190   if (!HasExtract)
191     X = Scalar;
192 
193   auto *Load = dyn_cast<LoadInst>(X);
194   if (!canWidenLoad(Load, TTI))
195     return false;
196 
197   Type *ScalarTy = Scalar->getType();
198   uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
199   unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
200 
201   // Check safety of replacing the scalar load with a larger vector load.
202   // We use minimal alignment (maximum flexibility) because we only care about
203   // the dereferenceable region. When calculating cost and creating a new op,
204   // we may use a larger value based on alignment attributes.
205   Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
206   assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
207 
208   unsigned MinVecNumElts = MinVectorSize / ScalarSize;
209   auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
210   unsigned OffsetEltIndex = 0;
211   Align Alignment = Load->getAlign();
212   if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
213                                    &DT)) {
214     // It is not safe to load directly from the pointer, but we can still peek
215     // through gep offsets and check if it is safe to load from a base address with
216     // updated alignment. If it is, we can shuffle the element(s) into place
217     // after loading.
218     unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
219     APInt Offset(OffsetBitWidth, 0);
220     SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);
221 
222     // We want to shuffle the result down from a high element of a vector, so
223     // the offset must be positive.
224     if (Offset.isNegative())
225       return false;
226 
227     // The offset must be a multiple of the scalar element to shuffle cleanly
228     // in the element's size.
229     uint64_t ScalarSizeInBytes = ScalarSize / 8;
230     if (Offset.urem(ScalarSizeInBytes) != 0)
231       return false;
232 
233     // If we load MinVecNumElts, will our target element still be loaded?
234     OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
235     if (OffsetEltIndex >= MinVecNumElts)
236       return false;
237 
238     if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
239                                      &DT))
240       return false;
241 
242     // Update alignment with offset value. Note that the offset could be negated
243     // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
244     // negation does not change the result of the alignment calculation.
245     Alignment = commonAlignment(Alignment, Offset.getZExtValue());
246   }
247 
248   // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
249   // Use the greater of the alignment on the load or its source pointer.
250   Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
251   Type *LoadTy = Load->getType();
252   unsigned AS = Load->getPointerAddressSpace();
253   InstructionCost OldCost =
254       TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
255   APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
256   OldCost +=
257       TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
258                                    /* Insert */ true, HasExtract, CostKind);
259 
260   // New pattern: load VecPtr
261   InstructionCost NewCost =
262       TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
263   // Optionally, we are shuffling the loaded vector element(s) into place.
264   // For the mask, set everything but element 0 to poison to prevent poison from
265   // propagating from the extra loaded memory. This will also optionally
266   // shrink/grow the vector from the loaded size to the output size.
267   // We assume this operation has no cost in codegen if there was no offset.
268   // Note that we could use freeze to avoid poison problems, but then we might
269   // still need a shuffle to change the vector size.
270   auto *Ty = cast<FixedVectorType>(I.getType());
271   unsigned OutputNumElts = Ty->getNumElements();
272   SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
273   assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
274   Mask[0] = OffsetEltIndex;
275   if (OffsetEltIndex)
276     NewCost +=
277         TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask, CostKind);
278 
279   // We can aggressively convert to the vector form because the backend can
280   // invert this transform if it does not result in a performance win.
281   if (OldCost < NewCost || !NewCost.isValid())
282     return false;
283 
284   // It is safe and potentially profitable to load a vector directly:
285   // inselt undef, load Scalar, 0 --> load VecPtr
286   IRBuilder<> Builder(Load);
287   Value *CastedPtr =
288       Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
289   Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
290   VecLd = Builder.CreateShuffleVector(VecLd, Mask);
291 
292   replaceValue(I, *VecLd);
293   ++NumVecLoad;
294   return true;
295 }
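
// Illustrative example of the fold above, assuming a target with a 128-bit
// minimum vector register width (so MinVecNumElts == 4 for i32) and 16
// dereferenceable bytes at %p:
//   %s = load i32, ptr %p, align 4
//   %i = insertelement <4 x i32> poison, i32 %s, i64 0
// -->
//   %w = load <4 x i32>, ptr %p, align 4
//   %i = shufflevector <4 x i32> %w, <4 x i32> poison,
//        <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>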
296 
297 /// If we are loading a vector and then inserting it into a larger vector with
298 /// undefined elements, try to load the larger vector and eliminate the insert.
299 /// This removes a shuffle in IR and may allow combining of other loaded values.
300 bool VectorCombine::widenSubvectorLoad(Instruction &I) {
301   // Match subvector insert of fixed vector.
302   auto *Shuf = cast<ShuffleVectorInst>(&I);
303   if (!Shuf->isIdentityWithPadding())
304     return false;
305 
306   // Allow a non-canonical shuffle mask that is choosing elements from op1.
307   unsigned NumOpElts =
308       cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
309   unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
310     return M >= (int)(NumOpElts);
311   });
312 
313   auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
314   if (!canWidenLoad(Load, TTI))
315     return false;
316 
317   // We use minimal alignment (maximum flexibility) because we only care about
318   // the dereferenceable region. When calculating cost and creating a new op,
319   // we may use a larger value based on alignment attributes.
320   auto *Ty = cast<FixedVectorType>(I.getType());
321   Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
322   assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
323   Align Alignment = Load->getAlign();
324   if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT))
325     return false;
326 
327   Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
328   Type *LoadTy = Load->getType();
329   unsigned AS = Load->getPointerAddressSpace();
330 
331   // Original pattern: insert_subvector (load PtrOp)
332   // This conservatively assumes that the cost of a subvector insert into an
333   // undef value is 0. We could add that cost if the cost model accurately
334   // reflects the real cost of that operation.
335   InstructionCost OldCost =
336       TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
337 
338   // New pattern: load PtrOp
339   InstructionCost NewCost =
340       TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);
341 
342   // We can aggressively convert to the vector form because the backend can
343   // invert this transform if it does not result in a performance win.
344   if (OldCost < NewCost || !NewCost.isValid())
345     return false;
346 
347   IRBuilder<> Builder(Load);
348   Value *CastedPtr =
349       Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
350   Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
351   replaceValue(I, *VecLd);
352   ++NumVecLoad;
353   return true;
354 }
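
// Illustrative example of the widening above (types chosen for illustration),
// assuming the full <4 x float> is known dereferenceable at %p and the cost
// model approves:
//   %v = load <2 x float>, ptr %p, align 8
//   %r = shufflevector <2 x float> %v, <2 x float> poison,
//        <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
// -->
//   %r = load <4 x float>, ptr %p, align 8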
355 
356 /// Determine which, if any, of the inputs should be replaced by a shuffle
357 /// followed by extract from a different index.
358 ExtractElementInst *VectorCombine::getShuffleExtract(
359     ExtractElementInst *Ext0, ExtractElementInst *Ext1,
360     unsigned PreferredExtractIndex = InvalidIndex) const {
361   auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
362   auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
363   assert(Index0C && Index1C && "Expected constant extract indexes");
364 
365   unsigned Index0 = Index0C->getZExtValue();
366   unsigned Index1 = Index1C->getZExtValue();
367 
368   // If the extract indexes are identical, no shuffle is needed.
369   if (Index0 == Index1)
370     return nullptr;
371 
372   Type *VecTy = Ext0->getVectorOperand()->getType();
373   assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
374   InstructionCost Cost0 =
375       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
376   InstructionCost Cost1 =
377       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
378 
379   // If both costs are invalid, no shuffle is needed.
380   if (!Cost0.isValid() && !Cost1.isValid())
381     return nullptr;
382 
383   // We are extracting from 2 different indexes, so one operand must be shuffled
384   // before performing a vector operation and/or extract. The more expensive
385   // extract will be replaced by a shuffle.
386   if (Cost0 > Cost1)
387     return Ext0;
388   if (Cost1 > Cost0)
389     return Ext1;
390 
391   // If the costs are equal and there is a preferred extract index, shuffle the
392   // opposite operand.
393   if (PreferredExtractIndex == Index0)
394     return Ext1;
395   if (PreferredExtractIndex == Index1)
396     return Ext0;
397 
398   // Otherwise, replace the extract with the higher index.
399   return Index0 > Index1 ? Ext0 : Ext1;
400 }
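
// For example, with equal extract costs and no preferred index:
//   %e0 = extractelement <4 x i32> %v0, i64 0
//   %e1 = extractelement <4 x i32> %v1, i64 3
// the %e1 extract (higher index) is returned as the one to replace with a
// shuffle followed by an extract from lane 0.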
401 
402 /// Compare the relative costs of 2 extracts followed by scalar operation vs.
403 /// vector operation(s) followed by extract. Return true if the existing
404 /// instructions are cheaper than a vector alternative. Otherwise, return false
405 /// and if one of the extracts should be transformed to a shufflevector, set
406 /// \p ConvertToShuffle to that extract instruction.
407 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
408                                           ExtractElementInst *Ext1,
409                                           const Instruction &I,
410                                           ExtractElementInst *&ConvertToShuffle,
411                                           unsigned PreferredExtractIndex) {
412   auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
413   auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
414   assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
415 
416   unsigned Opcode = I.getOpcode();
417   Value *Ext0Src = Ext0->getVectorOperand();
418   Value *Ext1Src = Ext1->getVectorOperand();
419   Type *ScalarTy = Ext0->getType();
420   auto *VecTy = cast<VectorType>(Ext0Src->getType());
421   InstructionCost ScalarOpCost, VectorOpCost;
422 
423   // Get cost estimates for scalar and vector versions of the operation.
424   bool IsBinOp = Instruction::isBinaryOp(Opcode);
425   if (IsBinOp) {
426     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
427     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
428   } else {
429     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
430            "Expected a compare");
431     CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
432     ScalarOpCost = TTI.getCmpSelInstrCost(
433         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
434     VectorOpCost = TTI.getCmpSelInstrCost(
435         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
436   }
437 
438   // Get cost estimates for the extract elements. These costs will factor into
439   // both sequences.
440   unsigned Ext0Index = Ext0IndexC->getZExtValue();
441   unsigned Ext1Index = Ext1IndexC->getZExtValue();
442 
443   InstructionCost Extract0Cost =
444       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
445   InstructionCost Extract1Cost =
446       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);
447 
448   // A more expensive extract will always be replaced by a splat shuffle.
449   // For example, if Ext0 is more expensive:
450   // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
451   // extelt (opcode (splat V0, Ext0), V1), Ext1
452   // TODO: Evaluate whether that always results in lowest cost. Alternatively,
453   //       check the cost of creating a broadcast shuffle and shuffling both
454   //       operands to element 0.
455   unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
456   unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
457   InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
458 
459   // Extra uses of the extracts mean that we include those costs in the
460   // vector total because those instructions will not be eliminated.
461   InstructionCost OldCost, NewCost;
462   if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
463     // Handle a special case. If the 2 extracts are identical, adjust the
464     // formulas to account for that. The extra use charge allows for either the
465     // CSE'd pattern or an unoptimized form with identical values:
466     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
467     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
468                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
469     OldCost = CheapExtractCost + ScalarOpCost;
470     NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
471   } else {
472     // Handle the general case. Each extract is actually a different value:
473     // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
474     OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
475     NewCost = VectorOpCost + CheapExtractCost +
476               !Ext0->hasOneUse() * Extract0Cost +
477               !Ext1->hasOneUse() * Extract1Cost;
478   }
479 
480   ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
481   if (ConvertToShuffle) {
482     if (IsBinOp && DisableBinopExtractShuffle)
483       return true;
484 
485     // If we are extracting from 2 different indexes, then one operand must be
486     // shuffled before performing the vector operation. The shuffle mask is
487     // poison except for 1 lane that is being translated to the remaining
488     // extraction lane. Therefore, it is a splat shuffle. Ex:
489     // ShufMask = { poison, poison, 0, poison }
490     // TODO: The cost model has an option for a "broadcast" shuffle
491     //       (splat-from-element-0), but no option for a more general splat.
492     if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
493       SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
494                                    PoisonMaskElem);
495       ShuffleMask[BestInsIndex] = BestExtIndex;
496       NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
497                                     VecTy, ShuffleMask, CostKind, 0, nullptr,
498                                     {ConvertToShuffle});
499     } else {
500       NewCost +=
501           TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy,
502                              {}, CostKind, 0, nullptr, {ConvertToShuffle});
503     }
504   }
505 
506   // Aggressively form a vector op if the cost is equal because the transform
507   // may enable further optimization.
508   // Codegen can reverse this transform (scalarize) if it was not profitable.
509   return OldCost < NewCost;
510 }
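
// Worked example of the comparison above with hypothetical costs (extracts
// cost 2, scalar and vector ops cost 1, both extracts have one use, different
// source vectors, same extract index so no shuffle is required):
//   OldCost = Extract0Cost + Extract1Cost + ScalarOpCost = 2 + 2 + 1 = 5
//   NewCost = VectorOpCost + CheapExtractCost            = 1 + 2     = 3
// OldCost < NewCost is false, so the vector form is chosen.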
511 
512 /// Create a shuffle that translates (shifts) 1 element from the input vector
513 /// to a new element location.
514 static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
515                                  unsigned NewIndex, IRBuilder<> &Builder) {
516   // The shuffle mask is poison except for 1 lane that is being translated
517   // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
518   // ShufMask = { 2, poison, poison, poison }
519   auto *VecTy = cast<FixedVectorType>(Vec->getType());
520   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
521   ShufMask[NewIndex] = OldIndex;
522   return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
523 }
524 
525 /// Given an extract element instruction with constant index operand, shuffle
526 /// the source vector (shift the scalar element) to a NewIndex for extraction.
527 /// Return null if the input can be constant folded, so that we are not creating
528 /// unnecessary instructions.
529 static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
530                                             unsigned NewIndex,
531                                             IRBuilder<> &Builder) {
532   // Shufflevectors can only be created for fixed-width vectors.
533   Value *X = ExtElt->getVectorOperand();
534   if (!isa<FixedVectorType>(X->getType()))
535     return nullptr;
536 
537   // If the extract can be constant-folded, this IR has not been simplified.
538   // Defer to other passes to handle that rather than creating new instructions.
539   Value *C = ExtElt->getIndexOperand();
540   assert(isa<ConstantInt>(C) && "Expected a constant index operand");
541   if (isa<Constant>(X))
542     return nullptr;
543 
544   Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
545                                    NewIndex, Builder);
546   return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
547 }
548 
549 /// Try to reduce extract element costs by converting scalar compares to vector
550 /// compares followed by extract.
551 /// cmp (ext0 V0, C), (ext1 V1, C)
552 void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
553                                   ExtractElementInst *Ext1, Instruction &I) {
554   assert(isa<CmpInst>(&I) && "Expected a compare");
555   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
556              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
557          "Expected matching constant extract indexes");
558 
559   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
560   ++NumVecCmp;
561   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
562   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
563   Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
564   Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
565   replaceValue(I, *NewExt);
566 }
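
// Illustrative IR for the fold above:
//   %e0 = extractelement <4 x i32> %v0, i64 1
//   %e1 = extractelement <4 x i32> %v1, i64 1
//   %c  = icmp sgt i32 %e0, %e1
// -->
//   %vc = icmp sgt <4 x i32> %v0, %v1
//   %c  = extractelement <4 x i1> %vc, i64 1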
567 
568 /// Try to reduce extract element costs by converting scalar binops to vector
569 /// binops followed by extract.
570 /// bo (ext0 V0, C), (ext1 V1, C)
571 void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
572                                     ExtractElementInst *Ext1, Instruction &I) {
573   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
574   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
575              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
576          "Expected matching constant extract indexes");
577 
578   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
579   ++NumVecBO;
580   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
581   Value *VecBO =
582       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
583 
584   // All IR flags are safe to back-propagate because any potential poison
585   // created in unused vector elements is discarded by the extract.
586   if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
587     VecBOInst->copyIRFlags(&I);
588 
589   Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
590   replaceValue(I, *NewExt);
591 }
592 
593 /// Match an instruction with extracted vector operands.
594 bool VectorCombine::foldExtractExtract(Instruction &I) {
595   // It is not safe to transform things like div, urem, etc. because we may
596   // create undefined behavior when executing those on unknown vector elements.
597   if (!isSafeToSpeculativelyExecute(&I))
598     return false;
599 
600   Instruction *I0, *I1;
601   CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
602   if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
603       !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
604     return false;
605 
606   Value *V0, *V1;
607   uint64_t C0, C1;
608   if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
609       !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
610       V0->getType() != V1->getType())
611     return false;
612 
613   // If the scalar value 'I' is going to be re-inserted into a vector, then try
614   // to create an extract to that same element. The extract/insert can be
615   // reduced to a "select shuffle".
616   // TODO: If we add a larger pattern match that starts from an insert, this
617   //       probably becomes unnecessary.
618   auto *Ext0 = cast<ExtractElementInst>(I0);
619   auto *Ext1 = cast<ExtractElementInst>(I1);
620   uint64_t InsertIndex = InvalidIndex;
621   if (I.hasOneUse())
622     match(I.user_back(),
623           m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
624 
625   ExtractElementInst *ExtractToChange;
626   if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
627     return false;
628 
629   if (ExtractToChange) {
630     unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
631     ExtractElementInst *NewExtract =
632         translateExtract(ExtractToChange, CheapExtractIdx, Builder);
633     if (!NewExtract)
634       return false;
635     if (ExtractToChange == Ext0)
636       Ext0 = NewExtract;
637     else
638       Ext1 = NewExtract;
639   }
640 
641   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
642     foldExtExtCmp(Ext0, Ext1, I);
643   else
644     foldExtExtBinop(Ext0, Ext1, I);
645 
646   Worklist.push(Ext0);
647   Worklist.push(Ext1);
648   return true;
649 }
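
// Illustrative end-to-end example where the extract indexes differ and the
// cost model chooses to shuffle %v0 down to lane 0:
//   %e0 = extractelement <4 x float> %v0, i64 3
//   %e1 = extractelement <4 x float> %v1, i64 0
//   %r  = fadd float %e0, %e1
// -->
//   %s  = shufflevector <4 x float> %v0, <4 x float> poison,
//         <4 x i32> <i32 3, i32 poison, i32 poison, i32 poison>
//   %b  = fadd <4 x float> %s, %v1
//   %r  = extractelement <4 x float> %b, i64 0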
650 
651 /// Try to replace an extract + scalar fneg + insert with a vector fneg +
652 /// shuffle.
653 bool VectorCombine::foldInsExtFNeg(Instruction &I) {
654   // Match an insert (op (extract)) pattern.
655   Value *DestVec;
656   uint64_t Index;
657   Instruction *FNeg;
658   if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
659                              m_ConstantInt(Index))))
660     return false;
661 
662   // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
663   Value *SrcVec;
664   Instruction *Extract;
665   if (!match(FNeg, m_FNeg(m_CombineAnd(
666                        m_Instruction(Extract),
667                        m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
668     return false;
669 
670   auto *VecTy = cast<FixedVectorType>(I.getType());
671   auto *ScalarTy = VecTy->getScalarType();
672   auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
673   if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType())
674     return false;
675 
676   // Ignore bogus insert/extract index.
677   unsigned NumElts = VecTy->getNumElements();
678   if (Index >= NumElts)
679     return false;
680 
681   // We are inserting the negated element into the same lane that we extracted
682   // from. This is equivalent to a select-shuffle that chooses all but the
683   // negated element from the destination vector.
684   SmallVector<int> Mask(NumElts);
685   std::iota(Mask.begin(), Mask.end(), 0);
686   Mask[Index] = Index + NumElts;
687   InstructionCost OldCost =
688       TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) +
689       TTI.getVectorInstrCost(I, VecTy, CostKind, Index);
690 
691   // If the extract has one use, it will be eliminated, so count it in the
692   // original cost. If it has more than one use, ignore the cost because it will
693   // be the same before/after.
694   if (Extract->hasOneUse())
695     OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);
696 
697   InstructionCost NewCost =
698       TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) +
699       TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, CostKind);
700 
701   bool NeedLenChg = SrcVecTy->getNumElements() != NumElts;
702   // If the lengths of the two vectors are not equal, we need an extra
703   // length-changing shuffle for the source vector. Add this cost.
704   SmallVector<int> SrcMask;
705   if (NeedLenChg) {
706     SrcMask.assign(NumElts, PoisonMaskElem);
707     SrcMask[Index] = Index;
708     NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
709                                   SrcVecTy, SrcMask, CostKind);
710   }
711 
712   if (NewCost > OldCost)
713     return false;
714 
715   Value *NewShuf;
716   // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index
717   Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
718   if (NeedLenChg) {
719     // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
720     Value *LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
721     NewShuf = Builder.CreateShuffleVector(DestVec, LenChgShuf, Mask);
722   } else {
723     // shuffle DestVec, (fneg SrcVec), Mask
724     NewShuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
725   }
726 
727   replaceValue(I, *NewShuf);
728   return true;
729 }
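
// Illustrative IR for the equal-length case of the fold above (subject to the
// cost checks):
//   %e = extractelement <4 x float> %src, i64 2
//   %n = fneg float %e
//   %r = insertelement <4 x float> %dst, float %n, i64 2
// -->
//   %vn = fneg <4 x float> %src
//   %r  = shufflevector <4 x float> %dst, <4 x float> %vn,
//         <4 x i32> <i32 0, i32 1, i32 6, i32 3>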
730 
731 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
732 /// destination type followed by shuffle. This can enable further transforms by
733 /// moving bitcasts or shuffles together.
734 bool VectorCombine::foldBitcastShuffle(Instruction &I) {
735   Value *V0, *V1;
736   ArrayRef<int> Mask;
737   if (!match(&I, m_BitCast(m_OneUse(
738                      m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
739     return false;
740 
741   // 1) Do not fold a bitcast of a shuffle for scalable types. First, the shuffle
742   // cost for a scalable type is unknown; second, we cannot tell whether the
743   // narrowed shuffle mask for a scalable type is a splat or not.
744   // 2) Disallow non-vector casts.
745   // TODO: We could allow any shuffle.
746   auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
747   auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
748   if (!DestTy || !SrcTy)
749     return false;
750 
751   unsigned DestEltSize = DestTy->getScalarSizeInBits();
752   unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
753   if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
754     return false;
755 
756   bool IsUnary = isa<UndefValue>(V1);
757 
758   // For binary shuffles, only fold bitcast(shuffle(X,Y))
759   // if it won't increase the number of bitcasts.
760   if (!IsUnary) {
761     auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType());
762     auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType());
763     if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
764         !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
765       return false;
766   }
767 
768   SmallVector<int, 16> NewMask;
769   if (DestEltSize <= SrcEltSize) {
770     // The bitcast is from wide to narrow/equal elements. The shuffle mask can
771     // always be expanded to the equivalent form choosing narrower elements.
772     assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
773     unsigned ScaleFactor = SrcEltSize / DestEltSize;
774     narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
775   } else {
776     // The bitcast is from narrow elements to wide elements. The shuffle mask
777     // must choose consecutive elements to allow casting first.
778     assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
779     unsigned ScaleFactor = DestEltSize / SrcEltSize;
780     if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
781       return false;
782   }
783 
784   // Bitcast the shuffle src - keep its original width but use the destination
785   // scalar type.
786   unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
787   auto *NewShuffleTy =
788       FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
789   auto *OldShuffleTy =
790       FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
791   unsigned NumOps = IsUnary ? 1 : 2;
792 
793   // The new shuffle must not cost more than the old shuffle.
794   TargetTransformInfo::ShuffleKind SK =
795       IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
796               : TargetTransformInfo::SK_PermuteTwoSrc;
797 
798   InstructionCost NewCost =
799       TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CostKind) +
800       (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
801                                      TargetTransformInfo::CastContextHint::None,
802                                      CostKind));
803   InstructionCost OldCost =
804       TTI.getShuffleCost(SK, SrcTy, Mask, CostKind) +
805       TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
806                            TargetTransformInfo::CastContextHint::None,
807                            CostKind);
808 
809   LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n  OldCost: "
810                     << OldCost << " vs NewCost: " << NewCost << "\n");
811 
812   if (NewCost > OldCost || !NewCost.isValid())
813     return false;
814 
815   // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
816   ++NumShufOfBitcast;
817   Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
818   Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
819   Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
820   replaceValue(I, *Shuf);
821   return true;
822 }
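
// Illustrative IR for the wide-to-narrow case of the fold above:
//   %s = shufflevector <4 x i32> %v, <4 x i32> poison,
//        <4 x i32> <i32 3, i32 2, i32 1, i32 0>
//   %b = bitcast <4 x i32> %s to <8 x i16>
// -->
//   %c = bitcast <4 x i32> %v to <8 x i16>
//   %b = shufflevector <8 x i16> %c, <8 x i16> poison,
//        <8 x i32> <i32 6, i32 7, i32 4, i32 5, i32 2, i32 3, i32 0, i32 1>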
823 
824 /// VP Intrinsics whose vector operands are both splat values may be simplified
825 /// into the scalar version of the operation and the result splatted. This
826 /// can lead to scalarization down the line.
827 bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
828   if (!isa<VPIntrinsic>(I))
829     return false;
830   VPIntrinsic &VPI = cast<VPIntrinsic>(I);
831   Value *Op0 = VPI.getArgOperand(0);
832   Value *Op1 = VPI.getArgOperand(1);
833 
834   if (!isSplatValue(Op0) || !isSplatValue(Op1))
835     return false;
836 
837   // Check getSplatValue early in this function, to avoid doing unnecessary
838   // work.
839   Value *ScalarOp0 = getSplatValue(Op0);
840   Value *ScalarOp1 = getSplatValue(Op1);
841   if (!ScalarOp0 || !ScalarOp1)
842     return false;
843 
844   // For the binary VP intrinsics supported here, the result on disabled lanes
845   // is a poison value. For now, only do this simplification if all lanes
846   // are active.
847   // TODO: Relax the condition that all lanes are active by using insertelement
848   // on inactive lanes.
849   auto IsAllTrueMask = [](Value *MaskVal) {
850     if (Value *SplattedVal = getSplatValue(MaskVal))
851       if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
852         return ConstValue->isAllOnesValue();
853     return false;
854   };
855   if (!IsAllTrueMask(VPI.getArgOperand(2)))
856     return false;
857 
858   // Check to make sure we support scalarization of the intrinsic
859   Intrinsic::ID IntrID = VPI.getIntrinsicID();
860   if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
861     return false;
862 
863   // Calculate cost of splatting both operands into vectors and the vector
864   // intrinsic
865   VectorType *VecTy = cast<VectorType>(VPI.getType());
866   SmallVector<int> Mask;
867   if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
868     Mask.resize(FVTy->getNumElements(), 0);
869   InstructionCost SplatCost =
870       TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
871       TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, Mask,
872                          CostKind);
873 
874   // Calculate the cost of the VP Intrinsic
875   SmallVector<Type *, 4> Args;
876   for (Value *V : VPI.args())
877     Args.push_back(V->getType());
878   IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
879   InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
880   InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
881 
882   // Determine scalar opcode
883   std::optional<unsigned> FunctionalOpcode =
884       VPI.getFunctionalOpcode();
885   std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
886   if (!FunctionalOpcode) {
887     ScalarIntrID = VPI.getFunctionalIntrinsicID();
888     if (!ScalarIntrID)
889       return false;
890   }
891 
892   // Calculate cost of scalarizing
893   InstructionCost ScalarOpCost = 0;
894   if (ScalarIntrID) {
895     IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
896     ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
897   } else {
898     ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
899                                               VecTy->getScalarType(), CostKind);
900   }
901 
902   // The existing splats may be kept around if other instructions use them.
903   InstructionCost CostToKeepSplats =
904       (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
905   InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
906 
907   LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
908                     << "\n");
909   LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
910                     << ", Cost of scalarizing: " << NewCost << "\n");
911 
912   // We want to scalarize unless the vector variant actually has lower cost.
913   if (OldCost < NewCost || !NewCost.isValid())
914     return false;
915 
916   // Scalarize the intrinsic
917   ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
918   Value *EVL = VPI.getArgOperand(3);
919 
920   // If the VP op might introduce UB or poison, we can only scalarize it if we
921   // know the EVL > 0: when the EVL is zero, the original VP op is a no-op (and
922   // thus not UB), so scalarizing it unconditionally could introduce UB that the
923   // original did not have.
924   bool SafeToSpeculate;
925   if (ScalarIntrID)
926     SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID)
927                           .hasFnAttr(Attribute::AttrKind::Speculatable);
928   else
929     SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
930         *FunctionalOpcode, &VPI, nullptr, &AC, &DT);
931   if (!SafeToSpeculate &&
932       !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI)))
933     return false;
934 
935   Value *ScalarVal =
936       ScalarIntrID
937           ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
938                                     {ScalarOp0, ScalarOp1})
939           : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
940                                 ScalarOp0, ScalarOp1);
941 
942   replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
943   return true;
944 }
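
// Illustrative example of the scalarization above, where %x and %y are assumed
// to be splats of %a and %b and the mask is all-true:
//   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %x, <4 x i32> %y,
//            <4 x i1> <i1 true, i1 true, i1 true, i1 true>, i32 %evl)
// -->
//   %s   = add i32 %a, %b
//   %ins = insertelement <4 x i32> poison, i32 %s, i64 0
//   %r   = shufflevector <4 x i32> %ins, <4 x i32> poison,
//          <4 x i32> zeroinitializer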
945 
946 /// Match a vector binop or compare instruction with at least one inserted
947 /// scalar operand and convert to scalar binop/cmp followed by insertelement.
948 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
949   CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
950   Value *Ins0, *Ins1;
951   if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
952       !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
953     return false;
954 
955   // Do not convert the vector condition of a vector select into a scalar
956   // condition. That may cause problems for codegen because of differences in
957   // boolean formats and register-file transfers.
958   // TODO: Can we account for that in the cost model?
959   bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
960   if (IsCmp)
961     for (User *U : I.users())
962       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
963         return false;
964 
965   // Match against one or both scalar values being inserted into constant
966   // vectors:
967   // vec_op VecC0, (inselt VecC1, V1, Index)
968   // vec_op (inselt VecC0, V0, Index), VecC1
969   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
970   // TODO: Deal with mismatched index constants and variable indexes?
971   Constant *VecC0 = nullptr, *VecC1 = nullptr;
972   Value *V0 = nullptr, *V1 = nullptr;
973   uint64_t Index0 = 0, Index1 = 0;
974   if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
975                                m_ConstantInt(Index0))) &&
976       !match(Ins0, m_Constant(VecC0)))
977     return false;
978   if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
979                                m_ConstantInt(Index1))) &&
980       !match(Ins1, m_Constant(VecC1)))
981     return false;
982 
983   bool IsConst0 = !V0;
984   bool IsConst1 = !V1;
985   if (IsConst0 && IsConst1)
986     return false;
987   if (!IsConst0 && !IsConst1 && Index0 != Index1)
988     return false;
989 
990   auto *VecTy0 = cast<VectorType>(Ins0->getType());
991   auto *VecTy1 = cast<VectorType>(Ins1->getType());
992   if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
993       VecTy1->getElementCount().getKnownMinValue() <= Index1)
994     return false;
995 
996   // Bail for single insertion if it is a load.
997   // TODO: Handle this once getVectorInstrCost can model the cost of load/stores.
998   auto *I0 = dyn_cast_or_null<Instruction>(V0);
999   auto *I1 = dyn_cast_or_null<Instruction>(V1);
1000   if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
1001       (IsConst1 && I0 && I0->mayReadFromMemory()))
1002     return false;
1003 
1004   uint64_t Index = IsConst0 ? Index1 : Index0;
1005   Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
1006   Type *VecTy = I.getType();
1007   assert(VecTy->isVectorTy() &&
1008          (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
1009          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1010           ScalarTy->isPointerTy()) &&
1011          "Unexpected types for insert element into binop or cmp");
1012 
1013   unsigned Opcode = I.getOpcode();
1014   InstructionCost ScalarOpCost, VectorOpCost;
1015   if (IsCmp) {
1016     CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
1017     ScalarOpCost = TTI.getCmpSelInstrCost(
1018         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
1019     VectorOpCost = TTI.getCmpSelInstrCost(
1020         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1021   } else {
1022     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
1023     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
1024   }
1025 
1026   // Get cost estimate for the insert element. This cost will factor into
1027   // both sequences.
1028   InstructionCost InsertCost = TTI.getVectorInstrCost(
1029       Instruction::InsertElement, VecTy, CostKind, Index);
1030   InstructionCost OldCost =
1031       (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
1032   InstructionCost NewCost = ScalarOpCost + InsertCost +
1033                             (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
1034                             (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
1035 
1036   // We want to scalarize unless the vector variant actually has lower cost.
1037   if (OldCost < NewCost || !NewCost.isValid())
1038     return false;
1039 
1040   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1041   // inselt NewVecC, (scalar_op V0, V1), Index
1042   if (IsCmp)
1043     ++NumScalarCmp;
1044   else
1045     ++NumScalarBO;
1046 
1047   // For constant cases, extract the scalar element; this should constant fold.
1048   if (IsConst0)
1049     V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
1050   if (IsConst1)
1051     V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
1052 
1053   Value *Scalar =
1054       IsCmp ? Builder.CreateCmp(Pred, V0, V1)
1055             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
1056 
1057   Scalar->setName(I.getName() + ".scalar");
1058 
1059   // All IR flags are safe to back-propagate. There is no potential for extra
1060   // poison to be created by the scalar instruction.
1061   if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
1062     ScalarInst->copyIRFlags(&I);
1063 
1064   // Fold the vector constants in the original vectors into a new base vector.
1065   Value *NewVecC =
1066       IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
1067             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
1068   Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
1069   replaceValue(I, *Insert);
1070   return true;
1071 }
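
// Illustrative example of the scalarization above (one inserted scalar, one
// constant vector operand), subject to the cost checks:
//   %i = insertelement <2 x i64> <i64 poison, i64 42>, i64 %x, i64 0
//   %r = add <2 x i64> %i, <i64 7, i64 100>
// -->
//   %s = add i64 %x, 7
//   %r = insertelement <2 x i64> <i64 poison, i64 142>, i64 %s, i64 0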
1072 
1073 /// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1074 /// a vector into vector operations followed by extract. Note: The SLP pass
1075 /// may miss this pattern because of implementation problems.
1076 bool VectorCombine::foldExtractedCmps(Instruction &I) {
1077   auto *BI = dyn_cast<BinaryOperator>(&I);
1078 
1079   // We are looking for a scalar binop of booleans.
1080   // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1081   if (!BI || !I.getType()->isIntegerTy(1))
1082     return false;
1083 
1084   // The compare predicates should match, and each compare should have a
1085   // constant operand.
1086   Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
1087   Instruction *I0, *I1;
1088   Constant *C0, *C1;
1089   CmpPredicate P0, P1;
1090   // FIXME: Use CmpPredicate::getMatching here.
1091   if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
1092       !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) ||
1093       P0 != static_cast<CmpInst::Predicate>(P1))
1094     return false;
1095 
1096   // The compare operands must be extracts of the same vector with constant
1097   // extract indexes.
1098   Value *X;
1099   uint64_t Index0, Index1;
1100   if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) ||
1101       !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))
1102     return false;
1103 
1104   auto *Ext0 = cast<ExtractElementInst>(I0);
1105   auto *Ext1 = cast<ExtractElementInst>(I1);
1106   ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
1107   if (!ConvertToShuf)
1108     return false;
1109   assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1110          "Unknown ExtractElementInst");
1111 
1112   // The original scalar pattern is:
1113   // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1114   CmpInst::Predicate Pred = P0;
1115   unsigned CmpOpcode =
1116       CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp;
1117   auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
1118   if (!VecTy)
1119     return false;
1120 
1121   InstructionCost Ext0Cost =
1122       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1123   InstructionCost Ext1Cost =
1124       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1125   InstructionCost CmpCost = TTI.getCmpSelInstrCost(
1126       CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred,
1127       CostKind);
1128 
1129   InstructionCost OldCost =
1130       Ext0Cost + Ext1Cost + CmpCost * 2 +
1131       TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);
1132 
1133   // The proposed vector pattern is:
1134   // vcmp = cmp Pred X, VecC
1135   // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1136   int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1137   int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1138   auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
1139   InstructionCost NewCost = TTI.getCmpSelInstrCost(
1140       CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred,
1141       CostKind);
1142   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1143   ShufMask[CheapIndex] = ExpensiveIndex;
1144   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
1145                                 ShufMask, CostKind);
1146   NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
1147   NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
1148   NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1149   NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1150 
1151   // Aggressively form vector ops if the cost is equal because the transform
1152   // may enable further optimization.
1153   // Codegen can reverse this transform (scalarize) if it was not profitable.
1154   if (OldCost < NewCost || !NewCost.isValid())
1155     return false;
1156 
1157   // Create a vector constant from the 2 scalar constants.
1158   SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1159                                    PoisonValue::get(VecTy->getElementType()));
1160   CmpC[Index0] = C0;
1161   CmpC[Index1] = C1;
1162   Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
1163   Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
1164   Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1165   Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1166   Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
1167   Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
1168   replaceValue(I, *NewExt);
1169   ++NumVecCmpBO;
1170   return true;
1171 }
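
// Illustrative example of the fold above, assuming the cost model chooses to
// shuffle the lane-2 compare result down to lane 0:
//   %e0 = extractelement <4 x i32> %x, i64 0
//   %e1 = extractelement <4 x i32> %x, i64 2
//   %c0 = icmp sgt i32 %e0, 10
//   %c1 = icmp sgt i32 %e1, 20
//   %r  = and i1 %c0, %c1
// -->
//   %vc = icmp sgt <4 x i32> %x, <i32 10, i32 poison, i32 20, i32 poison>
//   %sh = shufflevector <4 x i1> %vc, <4 x i1> poison,
//         <4 x i32> <i32 2, i32 poison, i32 poison, i32 poison>
//   %vb = and <4 x i1> %vc, %sh
//   %r  = extractelement <4 x i1> %vb, i64 0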
1172 
1173 // Check if the memory location is modified between two instructions in the same BB.
1174 static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1175                                  BasicBlock::iterator End,
1176                                  const MemoryLocation &Loc, AAResults &AA) {
1177   unsigned NumScanned = 0;
1178   return std::any_of(Begin, End, [&](const Instruction &Instr) {
1179     return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1180            ++NumScanned > MaxInstrsToScan;
1181   });
1182 }
1183 
1184 namespace {
1185 /// Helper class to indicate whether a vector index can be safely scalarized and
1186 /// if a freeze needs to be inserted.
1187 class ScalarizationResult {
1188   enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1189 
1190   StatusTy Status;
1191   Value *ToFreeze;
1192 
1193   ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1194       : Status(Status), ToFreeze(ToFreeze) {}
1195 
1196 public:
1197   ScalarizationResult(const ScalarizationResult &Other) = default;
1198   ~ScalarizationResult() {
1199     assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1200   }
1201 
1202   static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1203   static ScalarizationResult safe() { return {StatusTy::Safe}; }
1204   static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1205     return {StatusTy::SafeWithFreeze, ToFreeze};
1206   }
1207 
1208   /// Returns true if the index can be scalarized without requiring a freeze.
1209   bool isSafe() const { return Status == StatusTy::Safe; }
1210   /// Returns true if the index cannot be scalarized.
1211   bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1212   /// Returns true if the index can be scalarized, but requires inserting a
1213   /// freeze.
1214   bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1215 
1216   /// Reset the state to Unsafe and clear ToFreeze if set.
1217   void discard() {
1218     ToFreeze = nullptr;
1219     Status = StatusTy::Unsafe;
1220   }
1221 
1222   /// Freeze ToFreeze and update the use in \p UserI to use the frozen value.
1223   void freeze(IRBuilder<> &Builder, Instruction &UserI) {
1224     assert(isSafeWithFreeze() &&
1225            "should only be used when freezing is required");
1226     assert(is_contained(ToFreeze->users(), &UserI) &&
1227            "UserI must be a user of ToFreeze");
1228     IRBuilder<>::InsertPointGuard Guard(Builder);
1229     Builder.SetInsertPoint(cast<Instruction>(&UserI));
1230     Value *Frozen =
1231         Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
1232     for (Use &U : make_early_inc_range((UserI.operands())))
1233       if (U.get() == ToFreeze)
1234         U.set(Frozen);
1235 
1236     ToFreeze = nullptr;
1237   }
1238 };
1239 } // namespace
1240 
1241 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1242 /// Idx. \p Idx must access a valid vector element.
1243 static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1244                                               Instruction *CtxI,
1245                                               AssumptionCache &AC,
1246                                               const DominatorTree &DT) {
1247   // We do checks for both fixed vector types and scalable vector types.
1248   // This is the number of elements of fixed vector types,
1249   // or the minimum number of elements of scalable vector types.
1250   uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1251 
1252   if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1253     if (C->getValue().ult(NumElements))
1254       return ScalarizationResult::safe();
1255     return ScalarizationResult::unsafe();
1256   }
1257 
1258   unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1259   APInt Zero(IntWidth, 0);
1260   APInt MaxElts(IntWidth, NumElements);
1261   ConstantRange ValidIndices(Zero, MaxElts);
1262   ConstantRange IdxRange(IntWidth, true);
1263 
1264   if (isGuaranteedNotToBePoison(Idx, &AC)) {
1265     if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
1266                                                    true, &AC, CtxI, &DT)))
1267       return ScalarizationResult::safe();
1268     return ScalarizationResult::unsafe();
1269   }
1270 
1271   // If the index may be poison, check if we can insert a freeze before the
1272   // range of the index is restricted.
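  // For example (illustrative IR), with a <4 x i32> vector:
  //   %idx.clamped = and i32 %idx, 3
  // is safe once the caller (via ScalarizationResult::freeze) rewrites it to:
  //   %idx.frozen  = freeze i32 %idx
  //   %idx.clamped = and i32 %idx.frozen, 3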
1273   Value *IdxBase;
1274   ConstantInt *CI;
1275   if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1276     IdxRange = IdxRange.binaryAnd(CI->getValue());
1277   } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1278     IdxRange = IdxRange.urem(CI->getValue());
1279   }
1280 
1281   if (ValidIndices.contains(IdxRange))
1282     return ScalarizationResult::safeWithFreeze(IdxBase);
1283   return ScalarizationResult::unsafe();
1284 }
1285 
1286 /// The memory operation on a vector of \p ScalarType had alignment of
1287 /// \p VectorAlignment. Compute the maximal, but conservatively correct,
1288 /// alignment that will be valid for the memory operation on a single scalar
1289 /// element of the same type with index \p Idx.
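/// For example (illustrative values): a <4 x i32> access with align 16 and
/// constant index 2 yields commonAlignment(16, 2 * 4) = align 8, while an
/// unknown index conservatively yields commonAlignment(16, 4) = align 4.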
1290 static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1291                                                 Type *ScalarType, Value *Idx,
1292                                                 const DataLayout &DL) {
1293   if (auto *C = dyn_cast<ConstantInt>(Idx))
1294     return commonAlignment(VectorAlignment,
1295                            C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1296   return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1297 }
1298 
1299 // Combine patterns like:
1300 //   %0 = load <4 x i32>, <4 x i32>* %a
1301 //   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1302 //   store <4 x i32> %1, <4 x i32>* %a
1303 // to:
1304 //   %0 = bitcast <4 x i32>* %a to i32*
1305 //   %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
1306 //   store i32 %b, i32* %1
1307 bool VectorCombine::foldSingleElementStore(Instruction &I) {
1308   auto *SI = cast<StoreInst>(&I);
1309   if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1310     return false;
1311 
1312   // TODO: Combine more complicated patterns (multiple insert) by referencing
1313   // TargetTransformInfo.
1314   Instruction *Source;
1315   Value *NewElement;
1316   Value *Idx;
1317   if (!match(SI->getValueOperand(),
1318              m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1319                          m_Value(Idx))))
1320     return false;
1321 
1322   if (auto *Load = dyn_cast<LoadInst>(Source)) {
1323     auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1324     Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1325     // Don't optimize for atomic/volatile loads or stores. Ensure memory is not
1326     // modified in between, the element size matches its store size, and the index is inbounds.
1327     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1328         !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1329         SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1330       return false;
1331 
1332     auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
1333     if (ScalarizableIdx.isUnsafe() ||
1334         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1335                              MemoryLocation::get(SI), AA))
1336       return false;
1337 
1338     if (ScalarizableIdx.isSafeWithFreeze())
1339       ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1340     Value *GEP = Builder.CreateInBoundsGEP(
1341         SI->getValueOperand()->getType(), SI->getPointerOperand(),
1342         {ConstantInt::get(Idx->getType(), 0), Idx});
1343     StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
1344     NSI->copyMetadata(*SI);
1345     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1346         std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
1347         *DL);
1348     NSI->setAlignment(ScalarOpAlignment);
1349     replaceValue(I, *NSI);
1350     eraseInstruction(I);
1351     return true;
1352   }
1353 
1354   return false;
1355 }
1356 
1357 /// Try to scalarize vector loads feeding extractelement instructions.
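/// For example (illustrative IR, applied only if deemed profitable):
///   %v  = load <4 x i32>, ptr %p
///   %e0 = extractelement <4 x i32> %v, i32 0
///   %e2 = extractelement <4 x i32> %v, i32 2
/// becomes:
///   %g0 = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i32 0
///   %e0 = load i32, ptr %g0
///   %g2 = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i32 2
///   %e2 = load i32, ptr %g2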
1358 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1359   Value *Ptr;
1360   if (!match(&I, m_Load(m_Value(Ptr))))
1361     return false;
1362 
1363   auto *VecTy = cast<VectorType>(I.getType());
1364   auto *LI = cast<LoadInst>(&I);
1365   if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
1366     return false;
1367 
1368   InstructionCost OriginalCost =
1369       TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
1370                           LI->getPointerAddressSpace(), CostKind);
1371   InstructionCost ScalarizedCost = 0;
1372 
1373   Instruction *LastCheckedInst = LI;
1374   unsigned NumInstChecked = 0;
1375   DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1376   auto FailureGuard = make_scope_exit([&]() {
1377     // If the transform is aborted, discard the ScalarizationResults.
1378     for (auto &Pair : NeedFreeze)
1379       Pair.second.discard();
1380   });
1381 
1382   // Check if all users of the load are extracts with no memory modifications
1383   // between the load and the extract. Compute the cost of both the original
1384   // code and the scalarized version.
1385   for (User *U : LI->users()) {
1386     auto *UI = dyn_cast<ExtractElementInst>(U);
1387     if (!UI || UI->getParent() != LI->getParent())
1388       return false;
1389 
1390     // Check if any instruction between the load and the extract may modify
1391     // memory.
1392     if (LastCheckedInst->comesBefore(UI)) {
1393       for (Instruction &I :
1394            make_range(std::next(LI->getIterator()), UI->getIterator())) {
1395         // Bail out if we reached the check limit or the instruction may write
1396         // to memory.
1397         if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
1398           return false;
1399         NumInstChecked++;
1400       }
1401       LastCheckedInst = UI;
1402     }
1403 
1404     auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
1405     if (ScalarIdx.isUnsafe())
1406       return false;
1407     if (ScalarIdx.isSafeWithFreeze()) {
1408       NeedFreeze.try_emplace(UI, ScalarIdx);
1409       ScalarIdx.discard();
1410     }
1411 
1412     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
1413     OriginalCost +=
1414         TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
1415                                Index ? Index->getZExtValue() : -1);
1416     ScalarizedCost +=
1417         TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
1418                             Align(1), LI->getPointerAddressSpace(), CostKind);
1419     ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
1420   }
1421 
1422   if (ScalarizedCost >= OriginalCost)
1423     return false;
1424 
1425   // Replace extracts with narrow scalar loads.
1426   for (User *U : LI->users()) {
1427     auto *EI = cast<ExtractElementInst>(U);
1428     Value *Idx = EI->getOperand(1);
1429 
1430     // Insert 'freeze' for poison indexes.
1431     auto It = NeedFreeze.find(EI);
1432     if (It != NeedFreeze.end())
1433       It->second.freeze(Builder, *cast<Instruction>(Idx));
1434 
1435     Builder.SetInsertPoint(EI);
1436     Value *GEP =
1437         Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
1438     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
1439         VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
1440 
1441     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1442         LI->getAlign(), VecTy->getElementType(), Idx, *DL);
1443     NewLoad->setAlignment(ScalarOpAlignment);
1444 
1445     replaceValue(*EI, *NewLoad);
1446   }
1447 
1448   FailureGuard.release();
1449   return true;
1450 }
1451 
1452 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
1453 /// to "(bitcast (concat X, Y))"
1454 /// where X/Y are bitcasted from i1 mask vectors.
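/// For example (illustrative IR, little-endian):
///   %xi = bitcast <4 x i1> %x to i4
///   %yi = bitcast <4 x i1> %y to i4
///   %xz = zext i4 %xi to i8
///   %yz = zext i4 %yi to i8
///   %ys = shl i8 %yz, 4
///   %r  = or disjoint i8 %xz, %ys
/// becomes (if the cost model approves):
///   %c  = shufflevector <4 x i1> %x, <4 x i1> %y,
///                       <8 x i32> <i32 0, i32 1, i32 2, i32 3,
///                                  i32 4, i32 5, i32 6, i32 7>
///   %r  = bitcast <8 x i1> %c to i8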
1455 bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
1456   Type *Ty = I.getType();
1457   if (!Ty->isIntegerTy())
1458     return false;
1459 
1460   // TODO: Add big endian test coverage
1461   if (DL->isBigEndian())
1462     return false;
1463 
1464   // Restrict to disjoint cases so the mask vectors aren't overlapping.
1465   Instruction *X, *Y;
1466   if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
1467     return false;
1468 
1469   // Allow both sources to contain a shl, to handle the more generic pattern:
1470   // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
1471   Value *SrcX;
1472   uint64_t ShAmtX = 0;
1473   if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
1474       !match(X, m_OneUse(
1475                     m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
1476                           m_ConstantInt(ShAmtX)))))
1477     return false;
1478 
1479   Value *SrcY;
1480   uint64_t ShAmtY = 0;
1481   if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
1482       !match(Y, m_OneUse(
1483                     m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
1484                           m_ConstantInt(ShAmtY)))))
1485     return false;
1486 
1487   // Canonicalize larger shift to the RHS.
1488   if (ShAmtX > ShAmtY) {
1489     std::swap(X, Y);
1490     std::swap(SrcX, SrcY);
1491     std::swap(ShAmtX, ShAmtY);
1492   }
1493 
1494   // Ensure both sources are matching vXi1 bool mask types, and that the shift
1495   // difference is the mask width so they can be easily concatenated together.
1496   uint64_t ShAmtDiff = ShAmtY - ShAmtX;
1497   unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
1498   unsigned BitWidth = Ty->getPrimitiveSizeInBits();
1499   auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
1500   if (!MaskTy || SrcX->getType() != SrcY->getType() ||
1501       !MaskTy->getElementType()->isIntegerTy(1) ||
1502       MaskTy->getNumElements() != ShAmtDiff ||
1503       MaskTy->getNumElements() > (BitWidth / 2))
1504     return false;
1505 
1506   auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
1507   auto *ConcatIntTy =
1508       Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
1509   auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
1510 
1511   SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
1512   std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
1513 
1514   // TODO: Is it worth supporting multi-use cases?
1515   InstructionCost OldCost = 0;
1516   OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
1517   OldCost +=
1518       NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1519   OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
1520                                       TTI::CastContextHint::None, CostKind);
1521   OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
1522                                       TTI::CastContextHint::None, CostKind);
1523 
1524   InstructionCost NewCost = 0;
1525   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
1526                                 ConcatMask, CostKind);
1527   NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
1528                                   TTI::CastContextHint::None, CostKind);
1529   if (Ty != ConcatIntTy)
1530     NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
1531                                     TTI::CastContextHint::None, CostKind);
1532   if (ShAmtX > 0)
1533     NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1534 
1535   LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
1536                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1537                     << "\n");
1538 
1539   if (NewCost > OldCost)
1540     return false;
1541 
1542   // Build bool mask concatenation, bitcast back to scalar integer, and perform
1543   // any residual zero-extension or shifting.
1544   Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
1545   Worklist.pushValue(Concat);
1546 
1547   Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
1548 
1549   if (Ty != ConcatIntTy) {
1550     Worklist.pushValue(Result);
1551     Result = Builder.CreateZExt(Result, Ty);
1552   }
1553 
1554   if (ShAmtX > 0) {
1555     Worklist.pushValue(Result);
1556     Result = Builder.CreateShl(Result, ShAmtX);
1557   }
1558 
1559   replaceValue(I, *Result);
1560   return true;
1561 }
1562 
1563 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1564 ///           -->  "binop (shuffle), (shuffle)".
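/// For example (illustrative IR; masks shown symbolically):
///   %s0 = shufflevector <4 x i32> %a, <4 x i32> %b, M0
///   %s1 = shufflevector <4 x i32> %c, <4 x i32> %d, M1
///   %op = add <4 x i32> %s0, %s1
///   %r  = shufflevector <4 x i32> %op, <4 x i32> poison, OuterM
/// becomes (if the cost model approves):
///   %n0 = shufflevector <4 x i32> %a, <4 x i32> %b, M0[OuterM]
///   %n1 = shufflevector <4 x i32> %c, <4 x i32> %d, M1[OuterM]
///   %r  = add <4 x i32> %n0, %n1
/// where Mi[OuterM] is the inner mask re-indexed through the outer mask.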
1565 bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1566   BinaryOperator *BinOp;
1567   ArrayRef<int> OuterMask;
1568   if (!match(&I,
1569              m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask))))
1570     return false;
1571 
1572   // Don't introduce poison into div/rem.
1573   if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem))
1574     return false;
1575 
1576   Value *Op00, *Op01;
1577   ArrayRef<int> Mask0;
1578   if (!match(BinOp->getOperand(0),
1579              m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)))))
1580     return false;
1581 
1582   Value *Op10, *Op11;
1583   ArrayRef<int> Mask1;
1584   if (!match(BinOp->getOperand(1),
1585              m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)))))
1586     return false;
1587 
1588   Instruction::BinaryOps Opcode = BinOp->getOpcode();
1589   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1590   auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
1591   auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
1592   auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
1593   if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1594     return false;
1595 
1596   unsigned NumSrcElts = BinOpTy->getNumElements();
1597 
1598   // Don't accept shuffles that reference the second operand in
1599   // div/rem or if it is an undef arg.
1600   if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) &&
1601       any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
1602     return false;
1603 
1604   // Merge outer / inner shuffles.
1605   SmallVector<int> NewMask0, NewMask1;
1606   for (int M : OuterMask) {
1607     if (M < 0 || M >= (int)NumSrcElts) {
1608       NewMask0.push_back(PoisonMaskElem);
1609       NewMask1.push_back(PoisonMaskElem);
1610     } else {
1611       NewMask0.push_back(Mask0[M]);
1612       NewMask1.push_back(Mask1[M]);
1613     }
1614   }
1615 
1616   // Try to merge shuffles across the binop if the new shuffles are not costly.
1617   InstructionCost OldCost =
1618       TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) +
1619       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1620                          OuterMask, CostKind, 0, nullptr, {BinOp}, &I) +
1621       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1622                          CostKind, 0, nullptr, {Op00, Op01},
1623                          cast<Instruction>(BinOp->getOperand(0))) +
1624       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1625                          CostKind, 0, nullptr, {Op10, Op11},
1626                          cast<Instruction>(BinOp->getOperand(1)));
1627 
1628   InstructionCost NewCost =
1629       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1630                          CostKind, 0, nullptr, {Op00, Op01}) +
1631       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1632                          CostKind, 0, nullptr, {Op10, Op11}) +
1633       TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
1634 
1635   LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
1636                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1637                     << "\n");
1638 
1639   // If costs are equal, still fold as we reduce instruction count.
1640   if (NewCost > OldCost)
1641     return false;
1642 
1643   Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0);
1644   Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1);
1645   Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
1646 
1647   // Intersect flags from the old binops.
1648   if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1649     NewInst->copyIRFlags(BinOp);
1650 
1651   Worklist.pushValue(Shuf0);
1652   Worklist.pushValue(Shuf1);
1653   replaceValue(I, *NewBO);
1654   return true;
1655 }
1656 
1657 /// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
1658 /// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
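/// For example (illustrative IR; the mask M is shown symbolically):
///   %a = add <4 x i32> %x, %y
///   %b = add <4 x i32> %z, %w
///   %r = shufflevector <4 x i32> %a, <4 x i32> %b, <4 x i32> M
/// becomes (if the cost model approves):
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> %z, <4 x i32> M
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> %w, <4 x i32> M
///   %r  = add <4 x i32> %s0, %s1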
1659 bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
1660   ArrayRef<int> OldMask;
1661   Instruction *LHS, *RHS;
1662   if (!match(&I, m_Shuffle(m_OneUse(m_Instruction(LHS)),
1663                            m_OneUse(m_Instruction(RHS)), m_Mask(OldMask))))
1664     return false;
1665 
1666   // TODO: Add support for addlike etc.
1667   if (LHS->getOpcode() != RHS->getOpcode())
1668     return false;
1669 
1670   Value *X, *Y, *Z, *W;
1671   bool IsCommutative = false;
1672   CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
1673   if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
1674       match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
1675     auto *BO = cast<BinaryOperator>(LHS);
1676     // Don't introduce poison into div/rem.
1677     if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
1678       return false;
1679     IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
1680   } else if (match(LHS, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
1681              match(RHS, m_SpecificCmp(Pred, m_Value(Z), m_Value(W)))) {
1682     IsCommutative = cast<CmpInst>(LHS)->isCommutative();
1683   } else
1684     return false;
1685 
1686   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1687   auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType());
1688   auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType());
1689   if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
1690     return false;
1691 
1692   unsigned NumSrcElts = BinOpTy->getNumElements();
1693 
1694   // If we have something like "add X, Y" and "add Z, X", swap ops to match.
1695   if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
1696     std::swap(X, Y);
1697 
1698   auto ConvertToUnary = [NumSrcElts](int &M) {
1699     if (M >= (int)NumSrcElts)
1700       M -= NumSrcElts;
1701   };
1702 
1703   SmallVector<int> NewMask0(OldMask);
1704   TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
1705   if (X == Z) {
1706     llvm::for_each(NewMask0, ConvertToUnary);
1707     SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
1708     Z = PoisonValue::get(BinOpTy);
1709   }
1710 
1711   SmallVector<int> NewMask1(OldMask);
1712   TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
1713   if (Y == W) {
1714     llvm::for_each(NewMask1, ConvertToUnary);
1715     SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
1716     W = PoisonValue::get(BinOpTy);
1717   }
1718 
1719   // Try to replace a binop with a shuffle if the shuffle is not costly.
1720   InstructionCost OldCost =
1721       TTI.getInstructionCost(LHS, CostKind) +
1722       TTI.getInstructionCost(RHS, CostKind) +
1723       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinResTy,
1724                          OldMask, CostKind, 0, nullptr, {LHS, RHS}, &I);
1725 
1726   InstructionCost NewCost =
1727       TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) +
1728       TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W});
1729 
1730   if (Pred == CmpInst::BAD_ICMP_PREDICATE) {
1731     NewCost +=
1732         TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind);
1733   } else {
1734     auto *ShuffleCmpTy =
1735         FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
1736     NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy,
1737                                       ShuffleDstTy, Pred, CostKind);
1738   }
1739 
1740   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
1741                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1742                     << "\n");
1743 
1744   // If either shuffle will constant fold away, then fold for the same cost as
1745   // we will reduce the instruction count.
1746   bool ReducedInstCount = (isa<Constant>(X) && isa<Constant>(Z)) ||
1747                           (isa<Constant>(Y) && isa<Constant>(W));
1748   if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
1749     return false;
1750 
1751   Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
1752   Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1);
1753   Value *NewBO = Pred == CmpInst::BAD_ICMP_PREDICATE
1754                      ? Builder.CreateBinOp(
1755                            cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
1756                      : Builder.CreateCmp(Pred, Shuf0, Shuf1);
1757 
1758   // Intersect flags from the old binops.
1759   if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
1760     NewInst->copyIRFlags(LHS);
1761     NewInst->andIRFlags(RHS);
1762   }
1763 
1764   Worklist.pushValue(Shuf0);
1765   Worklist.pushValue(Shuf1);
1766   replaceValue(I, *NewBO);
1767   return true;
1768 }
1769 
1770 /// Try to convert "shuffle (castop x), (castop y)" into "castop (shuffle x, y)",
1771 /// where both casts share the same source type and a compatible cast opcode.
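/// For example (illustrative IR; the mask M is shown symbolically):
///   %c0 = zext <4 x i16> %x to <4 x i32>
///   %c1 = zext <4 x i16> %y to <4 x i32>
///   %r  = shufflevector <4 x i32> %c0, <4 x i32> %c1, <8 x i32> M
/// becomes (if the cost model approves):
///   %s  = shufflevector <4 x i16> %x, <4 x i16> %y, <8 x i32> M
///   %r  = zext <8 x i16> %s to <8 x i32>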
1772 bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
1773   Value *V0, *V1;
1774   ArrayRef<int> OldMask;
1775   if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
1776     return false;
1777 
1778   auto *C0 = dyn_cast<CastInst>(V0);
1779   auto *C1 = dyn_cast<CastInst>(V1);
1780   if (!C0 || !C1)
1781     return false;
1782 
1783   Instruction::CastOps Opcode = C0->getOpcode();
1784   if (C0->getSrcTy() != C1->getSrcTy())
1785     return false;
1786 
1787   // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
1788   if (Opcode != C1->getOpcode()) {
1789     if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
1790       Opcode = Instruction::SExt;
1791     else
1792       return false;
1793   }
1794 
1795   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1796   auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
1797   auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
1798   if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
1799     return false;
1800 
1801   unsigned NumSrcElts = CastSrcTy->getNumElements();
1802   unsigned NumDstElts = CastDstTy->getNumElements();
1803   assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
1804          "Only bitcasts expected to alter src/dst element counts");
1805 
1806   // Bail on bitcasts where neither element count is an exact multiple of the
1807   // other, e.g. <32 x i40> -> <40 x i32>.
1808   if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
1809       (NumDstElts % NumSrcElts) != 0)
1810     return false;
1811 
1812   SmallVector<int, 16> NewMask;
1813   if (NumSrcElts >= NumDstElts) {
1814     // The bitcast is from wide to narrow/equal elements. The shuffle mask can
1815     // always be expanded to the equivalent form choosing narrower elements.
1816     assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
1817     unsigned ScaleFactor = NumSrcElts / NumDstElts;
1818     narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
1819   } else {
1820     // The bitcast is from narrow elements to wide elements. The shuffle mask
1821     // must choose consecutive elements to allow casting first.
1822     assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
1823     unsigned ScaleFactor = NumDstElts / NumSrcElts;
1824     if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
1825       return false;
1826   }
1827 
1828   auto *NewShuffleDstTy =
1829       FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());
1830 
1831   // Try to replace a castop with a shuffle if the shuffle is not costly.
1832   InstructionCost CostC0 =
1833       TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
1834                            TTI::CastContextHint::None, CostKind);
1835   InstructionCost CostC1 =
1836       TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
1837                            TTI::CastContextHint::None, CostKind);
1838   InstructionCost OldCost = CostC0 + CostC1;
1839   OldCost +=
1840       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy,
1841                          OldMask, CostKind, 0, nullptr, {}, &I);
1842 
1843   InstructionCost NewCost = TTI.getShuffleCost(
1844       TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind);
1845   NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
1846                                   TTI::CastContextHint::None, CostKind);
1847   if (!C0->hasOneUse())
1848     NewCost += CostC0;
1849   if (!C1->hasOneUse())
1850     NewCost += CostC1;
1851 
1852   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
1853                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1854                     << "\n");
1855   if (NewCost > OldCost)
1856     return false;
1857 
1858   Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
1859                                             C1->getOperand(0), NewMask);
1860   Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
1861 
1862   // Intersect flags from the old casts.
1863   if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
1864     NewInst->copyIRFlags(C0);
1865     NewInst->andIRFlags(C1);
1866   }
1867 
1868   Worklist.pushValue(Shuf);
1869   replaceValue(I, *Cast);
1870   return true;
1871 }
1872 
1873 /// Try to convert any of:
1874 /// "shuffle (shuffle x, undef), (shuffle y, undef)"
1875 /// "shuffle (shuffle x, undef), y"
1876 /// "shuffle x, (shuffle y, undef)"
1877 /// into "shuffle x, y".
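/// For example (illustrative IR; masks shown symbolically):
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> poison, M0
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> poison, M1
///   %r  = shufflevector <4 x i32> %s0, <4 x i32> %s1, OuterM
/// becomes (if the cost model approves) a single
///   shufflevector <4 x i32> %x, <4 x i32> %y, ...
/// whose mask is OuterM composed through M0 and M1.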
1878 bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
1879   ArrayRef<int> OuterMask;
1880   Value *OuterV0, *OuterV1;
1881   if (!match(&I,
1882              m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask))))
1883     return false;
1884 
1885   ArrayRef<int> InnerMask0, InnerMask1;
1886   Value *V0 = nullptr, *V1 = nullptr;
1887   UndefValue *U0 = nullptr, *U1 = nullptr;
1888   bool Match0 = match(
1889       OuterV0, m_Shuffle(m_Value(V0), m_UndefValue(U0), m_Mask(InnerMask0)));
1890   bool Match1 = match(
1891       OuterV1, m_Shuffle(m_Value(V1), m_UndefValue(U1), m_Mask(InnerMask1)));
1892   if (!Match0 && !Match1)
1893     return false;
1894 
1895   V0 = Match0 ? V0 : OuterV0;
1896   V1 = Match1 ? V1 : OuterV1;
1897   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1898   auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
1899   auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
1900   if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
1901       V0->getType() != V1->getType())
1902     return false;
1903 
1904   unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
1905   unsigned NumImmElts = ShuffleImmTy->getNumElements();
1906 
1907   // Bail if either inner mask references a RHS undef arg.
1908   if ((Match0 && !isa<PoisonValue>(U0) &&
1909        any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) ||
1910       (Match1 && !isa<PoisonValue>(U1) &&
1911        any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; })))
1912     return false;
1913 
1914   // Merge shuffles, replacing any index into the RHS poison arg with PoisonMaskElem.
1915   SmallVector<int, 16> NewMask(OuterMask);
1916   for (int &M : NewMask) {
1917     if (0 <= M && M < (int)NumImmElts) {
1918       if (Match0)
1919         M = (InnerMask0[M] >= (int)NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
1920     } else if (M >= (int)NumImmElts) {
1921       if (Match1) {
1922         if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts)
1923           M = PoisonMaskElem;
1924         else
1925           M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
1926       }
1927     }
1928   }
1929 
1930   // Have we folded to an Identity shuffle?
1931   if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
1932     replaceValue(I, *V0);
1933     return true;
1934   }
1935 
1936   // Try to merge the shuffles if the new shuffle is not costly.
1937   InstructionCost InnerCost0 = 0;
1938   if (Match0)
1939     InnerCost0 = TTI.getShuffleCost(
1940         TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, InnerMask0,
1941         CostKind, 0, nullptr, {V0, U0}, cast<ShuffleVectorInst>(OuterV0));
1942 
1943   InstructionCost InnerCost1 = 0;
1944   if (Match1)
1945     InnerCost1 = TTI.getShuffleCost(
1946         TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, InnerMask1,
1947         CostKind, 0, nullptr, {V1, U1}, cast<ShuffleVectorInst>(OuterV1));
1948 
1949   InstructionCost OuterCost = TTI.getShuffleCost(
1950       TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind,
1951       0, nullptr, {OuterV0, OuterV1}, &I);
1952 
1953   InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
1954 
1955   InstructionCost NewCost =
1956       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy,
1957                          NewMask, CostKind, 0, nullptr, {V0, V1});
1958   if (!OuterV0->hasOneUse())
1959     NewCost += InnerCost0;
1960   if (!OuterV1->hasOneUse())
1961     NewCost += InnerCost1;
1962 
1963   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
1964                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1965                     << "\n");
1966   if (NewCost > OldCost)
1967     return false;
1968 
1969   // Clear unused sources to poison.
1970   if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; }))
1971     V0 = PoisonValue::get(ShuffleSrcTy);
1972   if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; }))
1973     V1 = PoisonValue::get(ShuffleSrcTy);
1974 
1975   Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask);
1976   replaceValue(I, *Shuf);
1977   return true;
1978 }
1979 
1980 /// Try to convert
1981 /// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
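/// For example (illustrative IR; the mask M is shown symbolically):
///   %i0 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %a, <4 x i32> %b)
///   %i1 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %c, <4 x i32> %d)
///   %r  = shufflevector <4 x i32> %i0, <4 x i32> %i1, <8 x i32> M
/// becomes (if the cost model approves):
///   %s0 = shufflevector <4 x i32> %a, <4 x i32> %c, <8 x i32> M
///   %s1 = shufflevector <4 x i32> %b, <4 x i32> %d, <8 x i32> M
///   %r  = call <8 x i32> @llvm.smax.v8i32(<8 x i32> %s0, <8 x i32> %s1)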
1982 bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
1983   Value *V0, *V1;
1984   ArrayRef<int> OldMask;
1985   if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
1986                            m_Mask(OldMask))))
1987     return false;
1988 
1989   auto *II0 = dyn_cast<IntrinsicInst>(V0);
1990   auto *II1 = dyn_cast<IntrinsicInst>(V1);
1991   if (!II0 || !II1)
1992     return false;
1993 
1994   Intrinsic::ID IID = II0->getIntrinsicID();
1995   if (IID != II1->getIntrinsicID())
1996     return false;
1997 
1998   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1999   auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType());
2000   if (!ShuffleDstTy || !II0Ty)
2001     return false;
2002 
2003   if (!isTriviallyVectorizable(IID))
2004     return false;
2005 
2006   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2007     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI) &&
2008         II0->getArgOperand(I) != II1->getArgOperand(I))
2009       return false;
2010 
2011   InstructionCost OldCost =
2012       TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) +
2013       TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind) +
2014       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, II0Ty, OldMask,
2015                          CostKind, 0, nullptr, {II0, II1}, &I);
2016 
2017   SmallVector<Type *> NewArgsTy;
2018   InstructionCost NewCost = 0;
2019   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2020     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
2021       NewArgsTy.push_back(II0->getArgOperand(I)->getType());
2022     } else {
2023       auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
2024       NewArgsTy.push_back(FixedVectorType::get(VecTy->getElementType(),
2025                                                VecTy->getNumElements() * 2));
2026       NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
2027                                     VecTy, OldMask, CostKind);
2028     }
2029   IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
2030   NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
2031 
2032   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
2033                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
2034                     << "\n");
2035 
2036   if (NewCost > OldCost)
2037     return false;
2038 
2039   SmallVector<Value *> NewArgs;
2040   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2041     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
2042       NewArgs.push_back(II0->getArgOperand(I));
2043     } else {
2044       Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I),
2045                                                 II1->getArgOperand(I), OldMask);
2046       NewArgs.push_back(Shuf);
2047       Worklist.pushValue(Shuf);
2048     }
2049   Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
2050 
2051   // Intersect flags from the old intrinsics.
2052   if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) {
2053     NewInst->copyIRFlags(II0);
2054     NewInst->andIRFlags(II1);
2055   }
2056 
2057   replaceValue(I, *NewIntrinsic);
2058   return true;
2059 }
2060 
2061 using InstLane = std::pair<Use *, int>;
2062 
2063 static InstLane lookThroughShuffles(Use *U, int Lane) {
2064   while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) {
2065     unsigned NumElts =
2066         cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
2067     int M = SV->getMaskValue(Lane);
2068     if (M < 0)
2069       return {nullptr, PoisonMaskElem};
2070     if (static_cast<unsigned>(M) < NumElts) {
2071       U = &SV->getOperandUse(0);
2072       Lane = M;
2073     } else {
2074       U = &SV->getOperandUse(1);
2075       Lane = M - NumElts;
2076     }
2077   }
2078   return InstLane{U, Lane};
2079 }
2080 
2081 static SmallVector<InstLane>
2082 generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
2083   SmallVector<InstLane> NItem;
2084   for (InstLane IL : Item) {
2085     auto [U, Lane] = IL;
2086     InstLane OpLane =
2087         U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op),
2088                                 Lane)
2089           : InstLane{nullptr, PoisonMaskElem};
2090     NItem.emplace_back(OpLane);
2091   }
2092   return NItem;
2093 }
2094 
2095 /// Detect concat of multiple values into a vector
2096 static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
2097                          const TargetTransformInfo &TTI) {
2098   auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType());
2099   unsigned NumElts = Ty->getNumElements();
2100   if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
2101     return false;
2102 
2103   // Check that the concat is free, usually meaning that the type will be split
2104   // during legalization.
2105   SmallVector<int, 16> ConcatMask(NumElts * 2);
2106   std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
2107   if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, CostKind) != 0)
2108     return false;
2109 
2110   unsigned NumSlices = Item.size() / NumElts;
2111   // Currently we generate a tree of shuffles for the concats, which limits us
2112   // to a power of 2.
2113   if (!isPowerOf2_32(NumSlices))
2114     return false;
2115   for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
2116     Use *SliceV = Item[Slice * NumElts].first;
2117     if (!SliceV || SliceV->get()->getType() != Ty)
2118       return false;
2119     for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
2120       auto [V, Lane] = Item[Slice * NumElts + Elt];
2121       if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
2122         return false;
2123     }
2124   }
2125   return true;
2126 }
2127 
2128 static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
2129                                   const SmallPtrSet<Use *, 4> &IdentityLeafs,
2130                                   const SmallPtrSet<Use *, 4> &SplatLeafs,
2131                                   const SmallPtrSet<Use *, 4> &ConcatLeafs,
2132                                   IRBuilder<> &Builder,
2133                                   const TargetTransformInfo *TTI) {
2134   auto [FrontU, FrontLane] = Item.front();
2135 
2136   if (IdentityLeafs.contains(FrontU)) {
2137     return FrontU->get();
2138   }
2139   if (SplatLeafs.contains(FrontU)) {
2140     SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
2141     return Builder.CreateShuffleVector(FrontU->get(), Mask);
2142   }
2143   if (ConcatLeafs.contains(FrontU)) {
2144     unsigned NumElts =
2145         cast<FixedVectorType>(FrontU->get()->getType())->getNumElements();
2146     SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
2147     for (unsigned S = 0; S < Values.size(); ++S)
2148       Values[S] = Item[S * NumElts].first->get();
2149 
2150     while (Values.size() > 1) {
2151       NumElts *= 2;
2152       SmallVector<int, 16> Mask(NumElts, 0);
2153       std::iota(Mask.begin(), Mask.end(), 0);
2154       SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
2155       for (unsigned S = 0; S < NewValues.size(); ++S)
2156         NewValues[S] =
2157             Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
2158       Values = NewValues;
2159     }
2160     return Values[0];
2161   }
2162 
2163   auto *I = cast<Instruction>(FrontU->get());
2164   auto *II = dyn_cast<IntrinsicInst>(I);
2165   unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
2166   SmallVector<Value *> Ops(NumOps);
2167   for (unsigned Idx = 0; Idx < NumOps; Idx++) {
2168     if (II &&
2169         isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) {
2170       Ops[Idx] = II->getOperand(Idx);
2171       continue;
2172     }
2173     Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
2174                                    Ty, IdentityLeafs, SplatLeafs, ConcatLeafs,
2175                                    Builder, TTI);
2176   }
2177 
2178   SmallVector<Value *, 8> ValueList;
2179   for (const auto &Lane : Item)
2180     if (Lane.first)
2181       ValueList.push_back(Lane.first->get());
2182 
2183   Type *DstTy =
2184       FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
2185   if (auto *BI = dyn_cast<BinaryOperator>(I)) {
2186     auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
2187                                       Ops[0], Ops[1]);
2188     propagateIRFlags(Value, ValueList);
2189     return Value;
2190   }
2191   if (auto *CI = dyn_cast<CmpInst>(I)) {
2192     auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
2193     propagateIRFlags(Value, ValueList);
2194     return Value;
2195   }
2196   if (auto *SI = dyn_cast<SelectInst>(I)) {
2197     auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
2198     propagateIRFlags(Value, ValueList);
2199     return Value;
2200   }
2201   if (auto *CI = dyn_cast<CastInst>(I)) {
2202     auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(),
2203                                      Ops[0], DstTy);
2204     propagateIRFlags(Value, ValueList);
2205     return Value;
2206   }
2207   if (II) {
2208     auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
2209     propagateIRFlags(Value, ValueList);
2210     return Value;
2211   }
2212   assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
2213   auto *Value =
2214       Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
2215   propagateIRFlags(Value, ValueList);
2216   return Value;
2217 }
2218 
2219 // Starting from a shuffle, look up through operands tracking the shuffled index
2220 // of each lane. If we can simplify away the shuffles to identities then
2221 // do so.
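// For example (illustrative IR):
//   %alo = shufflevector <4 x i32> %a, <4 x i32> poison, <2 x i32> <i32 0, i32 1>
//   %ahi = shufflevector <4 x i32> %a, <4 x i32> poison, <2 x i32> <i32 2, i32 3>
//   %blo = shufflevector <4 x i32> %b, <4 x i32> poison, <2 x i32> <i32 0, i32 1>
//   %bhi = shufflevector <4 x i32> %b, <4 x i32> poison, <2 x i32> <i32 2, i32 3>
//   %lo  = add <2 x i32> %alo, %blo
//   %hi  = add <2 x i32> %ahi, %bhi
//   %r   = shufflevector <2 x i32> %lo, <2 x i32> %hi,
//                        <4 x i32> <i32 0, i32 1, i32 2, i32 3>
// Every lane of %r traces back to the matching lane of %a and %b, so the
// shuffles are superfluous and this can be rewritten as:
//   %r   = add <4 x i32> %a, %b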
2222 bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
2223   auto *Ty = dyn_cast<FixedVectorType>(I.getType());
2224   if (!Ty || I.use_empty())
2225     return false;
2226 
2227   SmallVector<InstLane> Start(Ty->getNumElements());
2228   for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
2229     Start[M] = lookThroughShuffles(&*I.use_begin(), M);
2230 
2231   SmallVector<SmallVector<InstLane>> Worklist;
2232   Worklist.push_back(Start);
2233   SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
2234   unsigned NumVisited = 0;
2235 
2236   while (!Worklist.empty()) {
2237     if (++NumVisited > MaxInstrsToScan)
2238       return false;
2239 
2240     SmallVector<InstLane> Item = Worklist.pop_back_val();
2241     auto [FrontU, FrontLane] = Item.front();
2242 
2243     // If we found an undef first lane then bail out to keep things simple.
2244     if (!FrontU)
2245       return false;
2246 
2247     // Helper to peek through bitcasts to the same value.
2248     auto IsEquiv = [&](Value *X, Value *Y) {
2249       return X->getType() == Y->getType() &&
2250              peekThroughBitcasts(X) == peekThroughBitcasts(Y);
2251     };
2252 
2253     // Look for an identity value.
2254     if (FrontLane == 0 &&
2255         cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() ==
2256             Ty->getNumElements() &&
2257         all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) {
2258           Value *FrontV = Item.front().first->get();
2259           return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) &&
2260                                       E.value().second == (int)E.index());
2261         })) {
2262       IdentityLeafs.insert(FrontU);
2263       continue;
2264     }
2265     // Look for constants, for the moment only supporting constant splats.
2266     if (auto *C = dyn_cast<Constant>(FrontU);
2267         C && C->getSplatValue() &&
2268         all_of(drop_begin(Item), [Item](InstLane &IL) {
2269           Value *FrontV = Item.front().first->get();
2270           Use *U = IL.first;
2271           return !U || (isa<Constant>(U->get()) &&
2272                         cast<Constant>(U->get())->getSplatValue() ==
2273                             cast<Constant>(FrontV)->getSplatValue());
2274         })) {
2275       SplatLeafs.insert(FrontU);
2276       continue;
2277     }
2278     // Look for a splat value.
2279     if (all_of(drop_begin(Item), [Item](InstLane &IL) {
2280           auto [FrontU, FrontLane] = Item.front();
2281           auto [U, Lane] = IL;
2282           return !U || (U->get() == FrontU->get() && Lane == FrontLane);
2283         })) {
2284       SplatLeafs.insert(FrontU);
2285       continue;
2286     }
2287 
2288     // We need each element to be the same type of value, and check that each
2289     // element has a single use.
2290     auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
2291       Value *FrontV = Item.front().first->get();
2292       if (!IL.first)
2293         return true;
2294       Value *V = IL.first->get();
2295       if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
2296         return false;
2297       if (V->getValueID() != FrontV->getValueID())
2298         return false;
2299       if (auto *CI = dyn_cast<CmpInst>(V))
2300         if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
2301           return false;
2302       if (auto *CI = dyn_cast<CastInst>(V))
2303         if (CI->getSrcTy()->getScalarType() !=
2304             cast<CastInst>(FrontV)->getSrcTy()->getScalarType())
2305           return false;
2306       if (auto *SI = dyn_cast<SelectInst>(V))
2307         if (!isa<VectorType>(SI->getOperand(0)->getType()) ||
2308             SI->getOperand(0)->getType() !=
2309                 cast<SelectInst>(FrontV)->getOperand(0)->getType())
2310           return false;
2311       if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
2312         return false;
2313       auto *II = dyn_cast<IntrinsicInst>(V);
2314       return !II || (isa<IntrinsicInst>(FrontV) &&
2315                      II->getIntrinsicID() ==
2316                          cast<IntrinsicInst>(FrontV)->getIntrinsicID() &&
2317                      !II->hasOperandBundles());
2318     };
2319     if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) {
2320       // Check the operator is one that we support.
2321       if (isa<BinaryOperator, CmpInst>(FrontU)) {
2322         //  We exclude div/rem in case they hit UB from poison lanes.
2323         if (auto *BO = dyn_cast<BinaryOperator>(FrontU);
2324             BO && BO->isIntDivRem())
2325           return false;
2326         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2327         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
2328         continue;
2329       } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
2330                      FPToUIInst, SIToFPInst, UIToFPInst>(FrontU)) {
2331         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2332         continue;
2333       } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) {
2334         // TODO: Handle vector widening/narrowing bitcasts.
2335         auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
2336         auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
2337         if (DstTy && SrcTy &&
2338             SrcTy->getNumElements() == DstTy->getNumElements()) {
2339           Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2340           continue;
2341         }
2342       } else if (isa<SelectInst>(FrontU)) {
2343         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2344         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
2345         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
2346         continue;
2347       } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
2348                  II && isTriviallyVectorizable(II->getIntrinsicID()) &&
2349                  !II->hasOperandBundles()) {
2350         for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
2351           if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op,
2352                                                  &TTI)) {
2353             if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
2354                   Value *FrontV = Item.front().first->get();
2355                   Use *U = IL.first;
2356                   return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
2357                                 cast<Instruction>(FrontV)->getOperand(Op));
2358                 }))
2359               return false;
2360             continue;
2361           }
2362           Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
2363         }
2364         continue;
2365       }
2366     }
2367 
2368     if (isFreeConcat(Item, CostKind, TTI)) {
2369       ConcatLeafs.insert(FrontU);
2370       continue;
2371     }
2372 
2373     return false;
2374   }
2375 
2376   if (NumVisited <= 1)
2377     return false;
2378 
2379   LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
2380 
2381   // If we got this far, we know the shuffles are superfluous and can be
2382   // removed. Scan through again and generate the new tree of instructions.
2383   Builder.SetInsertPoint(&I);
2384   Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
2385                                  ConcatLeafs, Builder, &TTI);
2386   replaceValue(I, *V);
2387   return true;
2388 }
2389 
2390 /// Given a commutative reduction, the order of the input lanes does not alter
2391 /// the results. We can use this to remove certain shuffles feeding the
2392 /// reduction, removing the need to shuffle at all.
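/// For example (illustrative IR):
///   %s = shufflevector <4 x i32> %x, <4 x i32> poison,
///                      <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %s)
/// The reduction is insensitive to lane order, so (if the cost model approves)
/// the mask can be re-sorted to the identity <0, 1, 2, 3>, leaving a shuffle
/// that trivially folds away.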
2393 bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2394   auto *II = dyn_cast<IntrinsicInst>(&I);
2395   if (!II)
2396     return false;
2397   switch (II->getIntrinsicID()) {
2398   case Intrinsic::vector_reduce_add:
2399   case Intrinsic::vector_reduce_mul:
2400   case Intrinsic::vector_reduce_and:
2401   case Intrinsic::vector_reduce_or:
2402   case Intrinsic::vector_reduce_xor:
2403   case Intrinsic::vector_reduce_smin:
2404   case Intrinsic::vector_reduce_smax:
2405   case Intrinsic::vector_reduce_umin:
2406   case Intrinsic::vector_reduce_umax:
2407     break;
2408   default:
2409     return false;
2410   }
2411 
2412   // Find all the inputs when looking through operations that do not alter the
2413   // lane order (binops, for example). Currently we look for a single shuffle,
2414   // and can ignore splat values.
2415   std::queue<Value *> Worklist;
2416   SmallPtrSet<Value *, 4> Visited;
2417   ShuffleVectorInst *Shuffle = nullptr;
2418   if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
2419     Worklist.push(Op);
2420 
2421   while (!Worklist.empty()) {
2422     Value *CV = Worklist.front();
2423     Worklist.pop();
2424     if (Visited.contains(CV))
2425       continue;
2426 
2427     // Splats don't change the order, so can be safely ignored.
2428     if (isSplatValue(CV))
2429       continue;
2430 
2431     Visited.insert(CV);
2432 
2433     if (auto *CI = dyn_cast<Instruction>(CV)) {
2434       if (CI->isBinaryOp()) {
2435         for (auto *Op : CI->operand_values())
2436           Worklist.push(Op);
2437         continue;
2438       } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
2439         if (Shuffle && Shuffle != SV)
2440           return false;
2441         Shuffle = SV;
2442         continue;
2443       }
2444     }
2445 
2446     // Anything else is currently an unknown node.
2447     return false;
2448   }
2449 
2450   if (!Shuffle)
2451     return false;
2452 
2453   // Check that all uses of the binary ops and shuffles are also included in
2454   // the lane-invariant operations (Visited should be the list of lane-wise
2455   // instructions, including the shuffle that we found).
2456   for (auto *V : Visited)
2457     for (auto *U : V->users())
2458       if (!Visited.contains(U) && U != &I)
2459         return false;
2460 
2461   FixedVectorType *VecType =
2462       dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
2463   if (!VecType)
2464     return false;
2465   FixedVectorType *ShuffleInputType =
2466       dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
2467   if (!ShuffleInputType)
2468     return false;
2469   unsigned NumInputElts = ShuffleInputType->getNumElements();
2470 
2471   // Find the mask by sorting the lanes into order. This is most likely to
2472   // become an identity or concat mask. Undef elements are pushed to the end.
2473   SmallVector<int> ConcatMask;
2474   Shuffle->getShuffleMask(ConcatMask);
2475   sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
2476   // In the case of a truncating shuffle it's possible for the mask
2477   // to have an index greater than the size of the resulting vector.
2478   // This requires special handling.
2479   bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts;
2480   bool UsesSecondVec =
2481       any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
2482 
2483   FixedVectorType *VecTyForCost =
2484       (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType;
2485   InstructionCost OldCost = TTI.getShuffleCost(
2486       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
2487       VecTyForCost, Shuffle->getShuffleMask(), CostKind);
2488   InstructionCost NewCost = TTI.getShuffleCost(
2489       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
2490       VecTyForCost, ConcatMask, CostKind);
2491 
2492   LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
2493                     << "\n");
2494   LLVM_DEBUG(dbgs() << "  OldCost: " << OldCost << " vs NewCost: " << NewCost
2495                     << "\n");
2496   if (NewCost < OldCost) {
2497     Builder.SetInsertPoint(Shuffle);
2498     Value *NewShuffle = Builder.CreateShuffleVector(
2499         Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
2500     LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
2501     replaceValue(*Shuffle, *NewShuffle);
2502   }
2503 
2504   // See if we can re-use foldSelectShuffle, getting it to reduce the size of
2505   // See if we can re-use foldSelectShuffle, getting it to reduce the size of
2506   // the shuffle and sort it into a nicer order, as it can ignore the order of the shuffles.
2507 }
2508 
2509 /// Determine if its more efficient to fold:
2510 ///   reduce(trunc(x)) -> trunc(reduce(x)).
2511 ///   reduce(sext(x))  -> sext(reduce(x)).
2512 ///   reduce(zext(x))  -> zext(reduce(x)).
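/// For example (illustrative IR):
///   %e = zext <8 x i16> %x to <8 x i32>
///   %r = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %e)
/// becomes (if the cost model approves):
///   %n = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> %x)
///   %r = zext i16 %n to i32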
2513 bool VectorCombine::foldCastFromReductions(Instruction &I) {
2514   auto *II = dyn_cast<IntrinsicInst>(&I);
2515   if (!II)
2516     return false;
2517 
2518   bool TruncOnly = false;
2519   Intrinsic::ID IID = II->getIntrinsicID();
2520   switch (IID) {
2521   case Intrinsic::vector_reduce_add:
2522   case Intrinsic::vector_reduce_mul:
2523     TruncOnly = true;
2524     break;
2525   case Intrinsic::vector_reduce_and:
2526   case Intrinsic::vector_reduce_or:
2527   case Intrinsic::vector_reduce_xor:
2528     break;
2529   default:
2530     return false;
2531   }
2532 
2533   unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
2534   Value *ReductionSrc = I.getOperand(0);
2535 
2536   Value *Src;
2537   if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
2538       (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
2539     return false;
2540 
2541   auto CastOpc =
2542       (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
2543 
2544   auto *SrcTy = cast<VectorType>(Src->getType());
2545   auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
2546   Type *ResultTy = I.getType();
2547 
2548   InstructionCost OldCost = TTI.getArithmeticReductionCost(
2549       ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
2550   OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
2551                                   TTI::CastContextHint::None, CostKind,
2552                                   cast<CastInst>(ReductionSrc));
2553   InstructionCost NewCost =
2554       TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
2555                                      CostKind) +
2556       TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
2557                            TTI::CastContextHint::None, CostKind);
2558 
2559   if (OldCost <= NewCost || !NewCost.isValid())
2560     return false;
2561 
2562   Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
2563                                                 II->getIntrinsicID(), {Src});
2564   Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
2565   replaceValue(I, *NewCast);
2566   return true;
2567 }
2568 
2569 /// This method looks for groups of shuffles acting on binops, of the form:
2570 ///  %x = shuffle ...
2571 ///  %y = shuffle ...
2572 ///  %a = binop %x, %y
2573 ///  %b = binop %x, %y
2574 ///  shuffle %a, %b, selectmask
2575 /// We may, especially if the shuffle is wider than legal, be able to convert
2576 /// the shuffle to a form where only parts of a and b need to be computed. On
2577 /// architectures with no obvious "select" shuffle, this can reduce the total
2578 /// number of operations if the target reports them as cheaper.
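/// A rough sketch of the rewritten form (the packed and reconstruction masks
/// are computed below from the lanes the select mask actually uses):
///  %xa = shuffle ..., packmask0
///  %ya = shuffle ..., packmask1
///  %xb = shuffle ..., packmask2
///  %yb = shuffle ..., packmask3
///  %a2 = binop %xa, %ya   ; only the selected lanes of %a
///  %b2 = binop %xb, %yb   ; only the selected lanes of %b
///  shuffle %a2, %b2, reconstructmask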
2579 bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
2580   auto *SVI = cast<ShuffleVectorInst>(&I);
2581   auto *VT = cast<FixedVectorType>(I.getType());
2582   auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
2583   auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
2584   if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
2585       VT != Op0->getType())
2586     return false;
2587 
2588   auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
2589   auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
2590   auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
2591   auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
2592   SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
2593   auto checkSVNonOpUses = [&](Instruction *I) {
2594     if (!I || I->getOperand(0)->getType() != VT)
2595       return true;
2596     return any_of(I->users(), [&](User *U) {
2597       return U != Op0 && U != Op1 &&
2598              !(isa<ShuffleVectorInst>(U) &&
2599                (InputShuffles.contains(cast<Instruction>(U)) ||
2600                 isInstructionTriviallyDead(cast<Instruction>(U))));
2601     });
2602   };
2603   if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
2604       checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
2605     return false;
2606 
2607   // Collect all the uses that are shuffles that we can transform together. We
2608   // may not have a single shuffle, but a group that can all be transformed
2609   // together profitably.
2610   SmallVector<ShuffleVectorInst *> Shuffles;
2611   auto collectShuffles = [&](Instruction *I) {
2612     for (auto *U : I->users()) {
2613       auto *SV = dyn_cast<ShuffleVectorInst>(U);
2614       if (!SV || SV->getType() != VT)
2615         return false;
2616       if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
2617           (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
2618         return false;
2619       if (!llvm::is_contained(Shuffles, SV))
2620         Shuffles.push_back(SV);
2621     }
2622     return true;
2623   };
2624   if (!collectShuffles(Op0) || !collectShuffles(Op1))
2625     return false;
2626   // When coming from a reduction, we need to be processing a single shuffle,
2627   // otherwise the other uses will not be lane-invariant.
2628   if (FromReduction && Shuffles.size() > 1)
2629     return false;
2630 
2631   // Add any shuffle uses for the shuffles we have found, to include them in our
2632   // cost calculations.
2633   if (!FromReduction) {
2634     for (ShuffleVectorInst *SV : Shuffles) {
2635       for (auto *U : SV->users()) {
2636         ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
2637         if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
2638           Shuffles.push_back(SSV);
2639       }
2640     }
2641   }
2642 
2643   // For each of the output shuffles, we try to sort all the first vector's
2644   // elements to the beginning, followed by the second vector's elements at the
2645   // end. If the binops are legalized to smaller vectors, this may reduce the
2646   // total number of binops. We compute the ReconstructMask needed to convert
2647   // back to the original lane order.
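  // E.g. with NumElts = 4, an output mask of <6, 2, 6, 2> only uses lane 2 of
  // each binop, so V1 and V2 each collect that single lane and the
  // ReconstructMask becomes <4, 0, 4, 0>.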
2648   SmallVector<std::pair<int, int>> V1, V2;
2649   SmallVector<SmallVector<int>> OrigReconstructMasks;
2650   int MaxV1Elt = 0, MaxV2Elt = 0;
2651   unsigned NumElts = VT->getNumElements();
2652   for (ShuffleVectorInst *SVN : Shuffles) {
2653     SmallVector<int> Mask;
2654     SVN->getShuffleMask(Mask);
2655 
2656     // Check that the operands are the same as the original, or reversed (in
2657     // which case we need to commute the mask).
2658     Value *SVOp0 = SVN->getOperand(0);
2659     Value *SVOp1 = SVN->getOperand(1);
2660     if (isa<UndefValue>(SVOp1)) {
2661       auto *SSV = cast<ShuffleVectorInst>(SVOp0);
2662       SVOp0 = SSV->getOperand(0);
2663       SVOp1 = SSV->getOperand(1);
2664       for (unsigned I = 0, E = Mask.size(); I != E; I++) {
2665         if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
2666           return false;
2667         Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
2668       }
2669     }
2670     if (SVOp0 == Op1 && SVOp1 == Op0) {
2671       std::swap(SVOp0, SVOp1);
2672       ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
2673     }
2674     if (SVOp0 != Op0 || SVOp1 != Op1)
2675       return false;
2676 
2677     // Calculate the reconstruction mask for this shuffle, as the mask needed to
2678     // take the packed values from Op0/Op1 and reconstruct the original lane
2679     // order.
2680     SmallVector<int> ReconstructMask;
2681     for (unsigned I = 0; I < Mask.size(); I++) {
2682       if (Mask[I] < 0) {
2683         ReconstructMask.push_back(-1);
2684       } else if (Mask[I] < static_cast<int>(NumElts)) {
2685         MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
2686         auto It = find_if(V1, [&](const std::pair<int, int> &A) {
2687           return Mask[I] == A.first;
2688         });
2689         if (It != V1.end())
2690           ReconstructMask.push_back(It - V1.begin());
2691         else {
2692           ReconstructMask.push_back(V1.size());
2693           V1.emplace_back(Mask[I], V1.size());
2694         }
2695       } else {
2696         MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
2697         auto It = find_if(V2, [&](const std::pair<int, int> &A) {
2698           return Mask[I] - static_cast<int>(NumElts) == A.first;
2699         });
2700         if (It != V2.end())
2701           ReconstructMask.push_back(NumElts + It - V2.begin());
2702         else {
2703           ReconstructMask.push_back(NumElts + V2.size());
2704           V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
2705         }
2706       }
2707     }
2708 
2709     // For reductions, we know that the output lane ordering doesn't alter the
2710     // result. An in-order mask can help simplify the shuffle away.
2711     if (FromReduction)
2712       sort(ReconstructMask);
2713     OrigReconstructMasks.push_back(std::move(ReconstructMask));
2714   }
2715 
2716   // If the maximum elements used from V1 and V2 are not larger than the new
2717   // vectors, the vectors are already packed and performing the optimization
2718   // again will likely not help any further. This also prevents us from getting
2719   // stuck in a cycle in case the costs do not also rule it out.
2720   if (V1.empty() || V2.empty() ||
2721       (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
2722        MaxV2Elt == static_cast<int>(V2.size()) - 1))
2723     return false;
2724 
2725   // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
2726   // shuffle of another shuffle, or not a shuffle (which is treated like an
2727   // identity shuffle).
2728   auto GetBaseMaskValue = [&](Instruction *I, int M) {
2729     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2730     if (!SV)
2731       return M;
2732     if (isa<UndefValue>(SV->getOperand(1)))
2733       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
2734         if (InputShuffles.contains(SSV))
2735           return SSV->getMaskValue(SV->getMaskValue(M));
2736     return SV->getMaskValue(M);
2737   };
2738 
2739   // Attempt to sort the inputs by ascending mask values to make simpler input
2740   // shuffles and push complex shuffles down to the uses. We sort on the first
2741   // of the two input shuffle orders, to try and get at least one input into a
2742   // nice order.
2743   auto SortBase = [&](Instruction *A, std::pair<int, int> X,
2744                       std::pair<int, int> Y) {
2745     int MXA = GetBaseMaskValue(A, X.first);
2746     int MYA = GetBaseMaskValue(A, Y.first);
2747     return MXA < MYA;
2748   };
2749   stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
2750     return SortBase(SVI0A, A, B);
2751   });
2752   stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
2753     return SortBase(SVI1A, A, B);
2754   });
2755   // Calculate our ReconstructMasks from the OrigReconstructMasks and the
2756   // modified order of the input shuffles.
2757   SmallVector<SmallVector<int>> ReconstructMasks;
2758   for (const auto &Mask : OrigReconstructMasks) {
2759     SmallVector<int> ReconstructMask;
2760     for (int M : Mask) {
2761       auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
2762         auto It = find_if(V, [M](auto A) { return A.second == M; });
2763         assert(It != V.end() && "Expected all entries in Mask");
2764         return std::distance(V.begin(), It);
2765       };
2766       if (M < 0)
2767         ReconstructMask.push_back(-1);
2768       else if (M < static_cast<int>(NumElts)) {
2769         ReconstructMask.push_back(FindIndex(V1, M));
2770       } else {
2771         ReconstructMask.push_back(NumElts + FindIndex(V2, M));
2772       }
2773     }
2774     ReconstructMasks.push_back(std::move(ReconstructMask));
2775   }
2776 
2777   // Calculate the masks needed for the new input shuffles, which get padded
2778   // with poison elements up to the original vector width.
2779   SmallVector<int> V1A, V1B, V2A, V2B;
2780   for (unsigned I = 0; I < V1.size(); I++) {
2781     V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
2782     V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
2783   }
2784   for (unsigned I = 0; I < V2.size(); I++) {
2785     V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
2786     V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
2787   }
2788   while (V1A.size() < NumElts) {
2789     V1A.push_back(PoisonMaskElem);
2790     V1B.push_back(PoisonMaskElem);
2791   }
2792   while (V2A.size() < NumElts) {
2793     V2A.push_back(PoisonMaskElem);
2794     V2B.push_back(PoisonMaskElem);
2795   }
2796 
2797   auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
2798     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2799     if (!SV)
2800       return C;
2801     return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
2802                                       ? TTI::SK_PermuteSingleSrc
2803                                       : TTI::SK_PermuteTwoSrc,
2804                                   VT, SV->getShuffleMask(), CostKind);
2805   };
2806   auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
2807     return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask, CostKind);
2808   };
2809 
2810   // Get the costs of the shuffles + binops before and after with the new
2811   // shuffle masks.
2812   InstructionCost CostBefore =
2813       TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
2814       TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
2815   CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
2816                                 InstructionCost(0), AddShuffleCost);
2817   CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
2818                                 InstructionCost(0), AddShuffleCost);
2819 
2820   // The new binops will be unused for lanes past the used shuffle lengths.
2821   // These types attempt to get the correct cost for that from the target.
2822   FixedVectorType *Op0SmallVT =
2823       FixedVectorType::get(VT->getScalarType(), V1.size());
2824   FixedVectorType *Op1SmallVT =
2825       FixedVectorType::get(VT->getScalarType(), V2.size());
2826   InstructionCost CostAfter =
2827       TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
2828       TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
2829   CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
2830                                InstructionCost(0), AddShuffleMaskCost);
2831   std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
2832   CostAfter +=
2833       std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
2834                       InstructionCost(0), AddShuffleMaskCost);
2835 
2836   LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
2837   LLVM_DEBUG(dbgs() << "  CostBefore: " << CostBefore
2838                     << " vs CostAfter: " << CostAfter << "\n");
2839   if (CostBefore <= CostAfter)
2840     return false;
2841 
2842   // The cost model has passed, create the new instructions.
2843   auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
2844     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2845     if (!SV)
2846       return I;
2847     if (isa<UndefValue>(SV->getOperand(1)))
2848       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
2849         if (InputShuffles.contains(SSV))
2850           return SSV->getOperand(Op);
2851     return SV->getOperand(Op);
2852   };
2853   Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
2854   Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
2855                                              GetShuffleOperand(SVI0A, 1), V1A);
2856   Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
2857   Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
2858                                              GetShuffleOperand(SVI0B, 1), V1B);
2859   Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
2860   Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
2861                                              GetShuffleOperand(SVI1A, 1), V2A);
2862   Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
2863   Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
2864                                              GetShuffleOperand(SVI1B, 1), V2B);
2865   Builder.SetInsertPoint(Op0);
2866   Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
2867                                     NSV0A, NSV0B);
2868   if (auto *I = dyn_cast<Instruction>(NOp0))
2869     I->copyIRFlags(Op0, true);
2870   Builder.SetInsertPoint(Op1);
2871   Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
2872                                     NSV1A, NSV1B);
2873   if (auto *I = dyn_cast<Instruction>(NOp1))
2874     I->copyIRFlags(Op1, true);
2875 
2876   for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
2877     Builder.SetInsertPoint(Shuffles[S]);
2878     Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
2879     replaceValue(*Shuffles[S], *NSV);
2880   }
2881 
2882   Worklist.pushValue(NSV0A);
2883   Worklist.pushValue(NSV0B);
2884   Worklist.pushValue(NSV1A);
2885   Worklist.pushValue(NSV1B);
2886   for (auto *S : Shuffles)
2887     Worklist.add(S);
2888   return true;
2889 }
2890 
2891 /// Check if the instruction depends on a ZExt that can be moved after the
2892 /// instruction. Move the ZExt if it is profitable. For example:
2893 ///     logic(zext(x),y) -> zext(logic(x,trunc(y)))
2894 ///     lshr(zext(x),y)  -> zext(lshr(x,trunc(y)))
2895 /// The cost model calculation takes into account whether zext(x) has other
2896 /// users and whether it can be propagated through them too.
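/// A concrete sketch, with illustrative vector types:
///   %zx = zext <8 x i8> %x to <8 x i32>
///   %r  = and <8 x i32> %zx, %y
/// may become, when the known-bits and cost checks below succeed:
///   %ty = trunc <8 x i32> %y to <8 x i8>
///   %na = and <8 x i8> %x, %ty
///   %r  = zext <8 x i8> %na to <8 x i32>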
2897 bool VectorCombine::shrinkType(Instruction &I) {
2898   Value *ZExted, *OtherOperand;
2899   if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
2900                                   m_Value(OtherOperand))) &&
2901       !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
2902     return false;
2903 
2904   Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
2905 
2906   auto *BigTy = cast<FixedVectorType>(I.getType());
2907   auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
2908   unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
2909 
2910   if (I.getOpcode() == Instruction::LShr) {
2911     // Check that the shift amount is less than the number of bits in the
2912     // smaller type. Otherwise, the smaller lshr will return a poison value.
2913     KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
2914     if (ShAmtKB.getMaxValue().uge(BW))
2915       return false;
2916   } else {
2917     // Check that the expression overall uses at most the same number of bits as
2918     // ZExted
2919     KnownBits KB = computeKnownBits(&I, *DL);
2920     if (KB.countMaxActiveBits() > BW)
2921       return false;
2922   }
2923 
2924   // Calculate the costs of leaving the current IR as it is versus moving the
2925   // ZExt operation later, along with adding truncates if needed.
2926   InstructionCost ZExtCost = TTI.getCastInstrCost(
2927       Instruction::ZExt, BigTy, SmallTy,
2928       TargetTransformInfo::CastContextHint::None, CostKind);
2929   InstructionCost CurrentCost = ZExtCost;
2930   InstructionCost ShrinkCost = 0;
2931 
2932   // Calculate total cost and check that we can propagate through all ZExt users
2933   for (User *U : ZExtOperand->users()) {
2934     auto *UI = cast<Instruction>(U);
2935     if (UI == &I) {
2936       CurrentCost +=
2937           TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2938       ShrinkCost +=
2939           TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2940       ShrinkCost += ZExtCost;
2941       continue;
2942     }
2943 
2944     if (!Instruction::isBinaryOp(UI->getOpcode()))
2945       return false;
2946 
2947     // Check if we can propagate ZExt through its other users
2948     KnownBits KB = computeKnownBits(UI, *DL);
2949     if (KB.countMaxActiveBits() > BW)
2950       return false;
2951 
2952     CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2953     ShrinkCost +=
2954         TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2955     ShrinkCost += ZExtCost;
2956   }
2957 
2958   // If the other instruction operand is not a constant, we'll need to
2959   // generate a truncate instruction, so we have to adjust the cost.
2960   if (!isa<Constant>(OtherOperand))
2961     ShrinkCost += TTI.getCastInstrCost(
2962         Instruction::Trunc, SmallTy, BigTy,
2963         TargetTransformInfo::CastContextHint::None, CostKind);
2964 
2965   // If the cost of shrinking types and leaving the IR is the same, we'll lean
2966   // towards modifying the IR because shrinking opens opportunities for other
2967   // shrinking optimisations.
2968   if (ShrinkCost > CurrentCost)
2969     return false;
2970 
2971   Builder.SetInsertPoint(&I);
2972   Value *Op0 = ZExted;
2973   Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
2974   // Keep the order of operands the same
2975   if (I.getOperand(0) == OtherOperand)
2976     std::swap(Op0, Op1);
2977   Value *NewBinOp =
2978       Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
2979   cast<Instruction>(NewBinOp)->copyIRFlags(&I);
2980   cast<Instruction>(NewBinOp)->copyMetadata(I);
2981   Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
2982   replaceValue(I, *NewZExtr);
2983   return true;
2984 }
2985 
2986 /// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
2987 /// shuffle (DstVec, SrcVec, Mask)
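/// For example (an illustrative sketch with <4 x i32>, ExtIdx=2, InsIdx=0):
///   %e = extractelement <4 x i32> %src, i64 2
///   %r = insertelement <4 x i32> %dst, i32 %e, i64 0
/// may become, if the shuffle is not more expensive:
///   %r = shufflevector <4 x i32> %dst, <4 x i32> %src,
///                      <4 x i32> <i32 6, i32 1, i32 2, i32 3>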
2988 bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
2989   Value *DstVec, *SrcVec;
2990   uint64_t ExtIdx, InsIdx;
2991   if (!match(&I,
2992              m_InsertElt(m_Value(DstVec),
2993                          m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
2994                          m_ConstantInt(InsIdx))))
2995     return false;
2996 
2997   auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
2998   if (!VecTy || SrcVec->getType() != VecTy)
2999     return false;
3000 
3001   unsigned NumElts = VecTy->getNumElements();
3002   if (ExtIdx >= NumElts || InsIdx >= NumElts)
3003     return false;
3004 
3005   SmallVector<int> Mask(NumElts, 0);
3006   std::iota(Mask.begin(), Mask.end(), 0);
3007   Mask[InsIdx] = ExtIdx + NumElts;
3008   // Compare the cost of the original extract + insert pair against the shuffle.
3009   auto *Ins = cast<InsertElementInst>(&I);
3010   auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
3011 
3012   InstructionCost OldCost =
3013       TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx) +
3014       TTI.getVectorInstrCost(*Ins, VecTy, CostKind, InsIdx);
3015 
3016   InstructionCost NewCost = TTI.getShuffleCost(
3017       TargetTransformInfo::SK_PermuteTwoSrc, VecTy, Mask, CostKind);
3018   if (!Ext->hasOneUse())
3019     NewCost += TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx);
3020 
3021   LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair : " << I
3022                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
3023                     << "\n");
3024 
3025   if (OldCost < NewCost)
3026     return false;
3027 
3028   // Canonicalize undef param to RHS to help further folds.
3029   if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
3030     ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
3031     std::swap(DstVec, SrcVec);
3032   }
3033 
3034   Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
3035   replaceValue(I, *Shuf);
3036 
3037   return true;
3038 }
3039 
3040 /// This is the entry point for all transforms. Pass manager differences are
3041 /// handled in the callers of this function.
3042 bool VectorCombine::run() {
3043   if (DisableVectorCombine)
3044     return false;
3045 
3046   // Don't attempt vectorization if the target does not support vectors.
3047   if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
3048     return false;
3049 
3050   LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
3051 
3052   bool MadeChange = false;
3053   auto FoldInst = [this, &MadeChange](Instruction &I) {
3054     Builder.SetInsertPoint(&I);
3055     bool IsVectorType = isa<VectorType>(I.getType());
3056     bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
3057     auto Opcode = I.getOpcode();
3058 
3059     LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
3060 
3061     // These folds should be beneficial regardless of when this pass is run
3062     // in the optimization pipeline.
3063     // The type checking is for run-time efficiency. We can avoid wasting time
3064     // dispatching to folding functions if there's no chance of matching.
3065     if (IsFixedVectorType) {
3066       switch (Opcode) {
3067       case Instruction::InsertElement:
3068         MadeChange |= vectorizeLoadInsert(I);
3069         break;
3070       case Instruction::ShuffleVector:
3071         MadeChange |= widenSubvectorLoad(I);
3072         break;
3073       default:
3074         break;
3075       }
3076     }
3077 
3078     // These transforms work with scalable and fixed vectors.
3079     // TODO: Identify and allow other scalable transforms
3080     if (IsVectorType) {
3081       MadeChange |= scalarizeBinopOrCmp(I);
3082       MadeChange |= scalarizeLoadExtract(I);
3083       MadeChange |= scalarizeVPIntrinsic(I);
3084     }
3085 
3086     if (Opcode == Instruction::Store)
3087       MadeChange |= foldSingleElementStore(I);
3088 
3089     // If this is an early pipeline invocation of this pass, we are done.
3090     if (TryEarlyFoldsOnly)
3091       return;
3092 
3093     // Otherwise, try folds that improve codegen but may interfere with
3094     // early IR canonicalizations.
3095     // The type checking is for run-time efficiency. We can avoid wasting time
3096     // dispatching to folding functions if there's no chance of matching.
3097     if (IsFixedVectorType) {
3098       switch (Opcode) {
3099       case Instruction::InsertElement:
3100         MadeChange |= foldInsExtFNeg(I);
3101         MadeChange |= foldInsExtVectorToShuffle(I);
3102         break;
3103       case Instruction::ShuffleVector:
3104         MadeChange |= foldPermuteOfBinops(I);
3105         MadeChange |= foldShuffleOfBinops(I);
3106         MadeChange |= foldShuffleOfCastops(I);
3107         MadeChange |= foldShuffleOfShuffles(I);
3108         MadeChange |= foldShuffleOfIntrinsics(I);
3109         MadeChange |= foldSelectShuffle(I);
3110         MadeChange |= foldShuffleToIdentity(I);
3111         break;
3112       case Instruction::BitCast:
3113         MadeChange |= foldBitcastShuffle(I);
3114         break;
3115       default:
3116         MadeChange |= shrinkType(I);
3117         break;
3118       }
3119     } else {
3120       switch (Opcode) {
3121       case Instruction::Call:
3122         MadeChange |= foldShuffleFromReductions(I);
3123         MadeChange |= foldCastFromReductions(I);
3124         break;
3125       case Instruction::ICmp:
3126       case Instruction::FCmp:
3127         MadeChange |= foldExtractExtract(I);
3128         break;
3129       case Instruction::Or:
3130         MadeChange |= foldConcatOfBoolMasks(I);
3131         [[fallthrough]];
3132       default:
3133         if (Instruction::isBinaryOp(Opcode)) {
3134           MadeChange |= foldExtractExtract(I);
3135           MadeChange |= foldExtractedCmps(I);
3136         }
3137         break;
3138       }
3139     }
3140   };
3141 
3142   for (BasicBlock &BB : F) {
3143     // Ignore unreachable basic blocks.
3144     if (!DT.isReachableFromEntry(&BB))
3145       continue;
3146     // Use early increment range so that we can erase instructions in loop.
3147     for (Instruction &I : make_early_inc_range(BB)) {
3148       if (I.isDebugOrPseudoInst())
3149         continue;
3150       FoldInst(I);
3151     }
3152   }
3153 
3154   while (!Worklist.isEmpty()) {
3155     Instruction *I = Worklist.removeOne();
3156     if (!I)
3157       continue;
3158 
3159     if (isInstructionTriviallyDead(I)) {
3160       eraseInstruction(*I);
3161       continue;
3162     }
3163 
3164     FoldInst(*I);
3165   }
3166 
3167   return MadeChange;
3168 }
3169 
3170 PreservedAnalyses VectorCombinePass::run(Function &F,
3171                                          FunctionAnalysisManager &FAM) {
3172   auto &AC = FAM.getResult<AssumptionAnalysis>(F);
3173   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
3174   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
3175   AAResults &AA = FAM.getResult<AAManager>(F);
3176   const DataLayout *DL = &F.getDataLayout();
3177   VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
3178                          TryEarlyFoldsOnly);
3179   if (!Combiner.run())
3180     return PreservedAnalyses::all();
3181   PreservedAnalyses PA;
3182   PA.preserveSet<CFGAnalyses>();
3183   return PA;
3184 }
3185