xref: /llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp (revision 82b5bda42c00179fcf101e3ab2c591bda94ada9a)
1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes scalar/vector interactions using target cost models. The
10 // transforms implemented here may not fit in traditional loop-based or SLP
11 // vectorization passes.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Vectorize/VectorCombine.h"
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/Analysis/AssumptionCache.h"
21 #include "llvm/Analysis/BasicAliasAnalysis.h"
22 #include "llvm/Analysis/GlobalsModRef.h"
23 #include "llvm/Analysis/Loads.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Analysis/VectorUtils.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/PatternMatch.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Transforms/Utils/Local.h"
33 #include "llvm/Transforms/Utils/LoopUtils.h"
34 #include <numeric>
35 #include <queue>
36 
37 #define DEBUG_TYPE "vector-combine"
38 #include "llvm/Transforms/Utils/InstructionWorklist.h"
39 
40 using namespace llvm;
41 using namespace llvm::PatternMatch;
42 
43 STATISTIC(NumVecLoad, "Number of vector loads formed");
44 STATISTIC(NumVecCmp, "Number of vector compares formed");
45 STATISTIC(NumVecBO, "Number of vector binops formed");
46 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
47 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
48 STATISTIC(NumScalarBO, "Number of scalar binops formed");
49 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
50 
51 static cl::opt<bool> DisableVectorCombine(
52     "disable-vector-combine", cl::init(false), cl::Hidden,
53     cl::desc("Disable all vector combine transforms"));
54 
55 static cl::opt<bool> DisableBinopExtractShuffle(
56     "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
57     cl::desc("Disable binop extract to shuffle transforms"));
58 
59 static cl::opt<unsigned> MaxInstrsToScan(
60     "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
61     cl::desc("Max number of instructions to scan for vector combining."));
62 
63 static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
64 
65 namespace {
66 class VectorCombine {
67 public:
68   VectorCombine(Function &F, const TargetTransformInfo &TTI,
69                 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
70                 const DataLayout *DL, TTI::TargetCostKind CostKind,
71                 bool TryEarlyFoldsOnly)
72       : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), DL(DL),
73         CostKind(CostKind), TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
74 
75   bool run();
76 
77 private:
78   Function &F;
79   IRBuilder<> Builder;
80   const TargetTransformInfo &TTI;
81   const DominatorTree &DT;
82   AAResults &AA;
83   AssumptionCache &AC;
84   const DataLayout *DL;
85   TTI::TargetCostKind CostKind;
86 
87   /// If true, only perform beneficial early IR transforms. Do not introduce new
88   /// vector operations.
89   bool TryEarlyFoldsOnly;
90 
91   InstructionWorklist Worklist;
92 
93   // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
94   //       parameter. That should be updated to specific sub-classes because the
95   //       run loop was changed to dispatch on opcode.
96   bool vectorizeLoadInsert(Instruction &I);
97   bool widenSubvectorLoad(Instruction &I);
98   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
99                                         ExtractElementInst *Ext1,
100                                         unsigned PreferredExtractIndex) const;
101   bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
102                              const Instruction &I,
103                              ExtractElementInst *&ConvertToShuffle,
104                              unsigned PreferredExtractIndex);
105   void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
106                      Instruction &I);
107   void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
108                        Instruction &I);
109   bool foldExtractExtract(Instruction &I);
110   bool foldInsExtFNeg(Instruction &I);
111   bool foldInsExtVectorToShuffle(Instruction &I);
112   bool foldBitcastShuffle(Instruction &I);
113   bool scalarizeBinopOrCmp(Instruction &I);
114   bool scalarizeVPIntrinsic(Instruction &I);
115   bool foldExtractedCmps(Instruction &I);
116   bool foldSingleElementStore(Instruction &I);
117   bool scalarizeLoadExtract(Instruction &I);
118   bool foldConcatOfBoolMasks(Instruction &I);
119   bool foldPermuteOfBinops(Instruction &I);
120   bool foldShuffleOfBinops(Instruction &I);
121   bool foldShuffleOfCastops(Instruction &I);
122   bool foldShuffleOfShuffles(Instruction &I);
123   bool foldShuffleOfIntrinsics(Instruction &I);
124   bool foldShuffleToIdentity(Instruction &I);
125   bool foldShuffleFromReductions(Instruction &I);
126   bool foldCastFromReductions(Instruction &I);
127   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
128   bool shrinkType(Instruction &I);
129 
130   void replaceValue(Value &Old, Value &New) {
131     Old.replaceAllUsesWith(&New);
132     if (auto *NewI = dyn_cast<Instruction>(&New)) {
133       New.takeName(&Old);
134       Worklist.pushUsersToWorkList(*NewI);
135       Worklist.pushValue(NewI);
136     }
137     Worklist.pushValue(&Old);
138   }
139 
140   void eraseInstruction(Instruction &I) {
141     LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
142     for (Value *Op : I.operands())
143       Worklist.pushValue(Op);
144     Worklist.remove(&I);
145     I.eraseFromParent();
146   }
147 };
148 } // namespace
149 
150 /// Return the source operand of a potentially bitcasted value. If there is no
151 /// bitcast, return the input value itself.
152 static Value *peekThroughBitcasts(Value *V) {
153   while (auto *BitCast = dyn_cast<BitCastInst>(V))
154     V = BitCast->getOperand(0);
155   return V;
156 }
157 
158 static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
159   // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
160   // The widened load may load data from dirty regions or create data races
161   // non-existent in the source.
162   if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
163       Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
164       mustSuppressSpeculation(*Load))
165     return false;
166 
167   // We are potentially transforming byte-sized (8-bit) memory accesses, so make
168   // sure we have all of our type-based constraints in place for this target.
169   Type *ScalarTy = Load->getType()->getScalarType();
170   uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
171   unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
172   if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
173       ScalarSize % 8 != 0)
174     return false;
175 
176   return true;
177 }
178 
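/// Try to replace a scalar load feeding an insert into lane 0 of a poison
/// vector with a wider vector load plus a shuffle. Illustrative IR sketch
/// (assuming a 128-bit minimum vector register and a dereferenceable %p):
///   %s = load float, ptr %p
///   %i = insertelement <4 x float> poison, float %s, i64 0
/// -->
///   %w = load <4 x float>, ptr %p
///   %i = shufflevector <4 x float> %w, <4 x float> poison,
///                      <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>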
179 bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
180   // Match insert into fixed vector of scalar value.
181   // TODO: Handle non-zero insert index.
182   Value *Scalar;
183   if (!match(&I,
184              m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt())))
185     return false;
186 
187   // Optionally match an extract from another vector.
188   Value *X;
189   bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
190   if (!HasExtract)
191     X = Scalar;
192 
193   auto *Load = dyn_cast<LoadInst>(X);
194   if (!canWidenLoad(Load, TTI))
195     return false;
196 
197   Type *ScalarTy = Scalar->getType();
198   uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
199   unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
200 
201   // Check safety of replacing the scalar load with a larger vector load.
202   // We use minimal alignment (maximum flexibility) because we only care about
203   // the dereferenceable region. When calculating cost and creating a new op,
204   // we may use a larger value based on alignment attributes.
205   Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
206   assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
207 
208   unsigned MinVecNumElts = MinVectorSize / ScalarSize;
209   auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
210   unsigned OffsetEltIndex = 0;
211   Align Alignment = Load->getAlign();
212   if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
213                                    &DT)) {
214     // It is not safe to load directly from the pointer, but we can still peek
215     // through gep offsets and check if it is safe to load from a base address with
216     // updated alignment. If it is, we can shuffle the element(s) into place
217     // after loading.
218     unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
219     APInt Offset(OffsetBitWidth, 0);
220     SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);
221 
222     // We want to shuffle the result down from a high element of a vector, so
223     // the offset must be positive.
224     if (Offset.isNegative())
225       return false;
226 
227     // The offset must be a multiple of the scalar element size so that the
228     // loaded value can be shuffled cleanly into a whole element position.
229     uint64_t ScalarSizeInBytes = ScalarSize / 8;
230     if (Offset.urem(ScalarSizeInBytes) != 0)
231       return false;
232 
233     // If we load MinVecNumElts, will our target element still be loaded?
234     OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
235     if (OffsetEltIndex >= MinVecNumElts)
236       return false;
237 
238     if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
239                                      &DT))
240       return false;
241 
242     // Update alignment with offset value. Note that the offset could be negated
243     // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
244     // negation does not change the result of the alignment calculation.
245     Alignment = commonAlignment(Alignment, Offset.getZExtValue());
246   }
247 
248   // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
249   // Use the greater of the alignment on the load or its source pointer.
250   Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
251   Type *LoadTy = Load->getType();
252   unsigned AS = Load->getPointerAddressSpace();
253   InstructionCost OldCost =
254       TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
255   APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
256   OldCost +=
257       TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
258                                    /* Insert */ true, HasExtract, CostKind);
259 
260   // New pattern: load VecPtr
261   InstructionCost NewCost =
262       TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
263   // Optionally, we are shuffling the loaded vector element(s) into place.
264   // For the mask set everything but element 0 to undef to prevent poison from
265   // propagating from the extra loaded memory. This will also optionally
266   // shrink/grow the vector from the loaded size to the output size.
267   // We assume this operation has no cost in codegen if there was no offset.
268   // Note that we could use freeze to avoid poison problems, but then we might
269   // still need a shuffle to change the vector size.
270   auto *Ty = cast<FixedVectorType>(I.getType());
271   unsigned OutputNumElts = Ty->getNumElements();
272   SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
273   assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
274   Mask[0] = OffsetEltIndex;
275   if (OffsetEltIndex)
276     NewCost +=
277         TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask, CostKind);
278 
279   // We can aggressively convert to the vector form because the backend can
280   // invert this transform if it does not result in a performance win.
281   if (OldCost < NewCost || !NewCost.isValid())
282     return false;
283 
284   // It is safe and potentially profitable to load a vector directly:
285   // inselt undef, load Scalar, 0 --> load VecPtr
286   IRBuilder<> Builder(Load);
287   Value *CastedPtr =
288       Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
289   Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
290   VecLd = Builder.CreateShuffleVector(VecLd, Mask);
291 
292   replaceValue(I, *VecLd);
293   ++NumVecLoad;
294   return true;
295 }
296 
297 /// If we are loading a vector and then inserting it into a larger vector with
298 /// undefined elements, try to load the larger vector and eliminate the insert.
299 /// This removes a shuffle in IR and may allow combining of other loaded values.
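/// Illustrative IR sketch (assuming the wider load is known dereferenceable):
///   %v = load <2 x float>, ptr %p
///   %i = shufflevector <2 x float> %v, <2 x float> poison,
///                      <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
/// -->
///   %i = load <4 x float>, ptr %p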
300 bool VectorCombine::widenSubvectorLoad(Instruction &I) {
301   // Match subvector insert of fixed vector.
302   auto *Shuf = cast<ShuffleVectorInst>(&I);
303   if (!Shuf->isIdentityWithPadding())
304     return false;
305 
306   // Allow a non-canonical shuffle mask that is choosing elements from op1.
307   unsigned NumOpElts =
308       cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
309   unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
310     return M >= (int)(NumOpElts);
311   });
312 
313   auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
314   if (!canWidenLoad(Load, TTI))
315     return false;
316 
317   // We use minimal alignment (maximum flexibility) because we only care about
318   // the dereferenceable region. When calculating cost and creating a new op,
319   // we may use a larger value based on alignment attributes.
320   auto *Ty = cast<FixedVectorType>(I.getType());
321   Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
322   assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
323   Align Alignment = Load->getAlign();
324   if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT))
325     return false;
326 
327   Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
328   Type *LoadTy = Load->getType();
329   unsigned AS = Load->getPointerAddressSpace();
330 
331   // Original pattern: insert_subvector (load PtrOp)
332   // This conservatively assumes that the cost of a subvector insert into an
333   // undef value is 0. We could add that cost if the cost model accurately
334   // reflects the real cost of that operation.
335   InstructionCost OldCost =
336       TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
337 
338   // New pattern: load PtrOp
339   InstructionCost NewCost =
340       TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);
341 
342   // We can aggressively convert to the vector form because the backend can
343   // invert this transform if it does not result in a performance win.
344   if (OldCost < NewCost || !NewCost.isValid())
345     return false;
346 
347   IRBuilder<> Builder(Load);
348   Value *CastedPtr =
349       Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
350   Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
351   replaceValue(I, *VecLd);
352   ++NumVecLoad;
353   return true;
354 }
355 
356 /// Determine which, if any, of the inputs should be replaced by a shuffle
357 /// followed by extract from a different index.
358 ExtractElementInst *VectorCombine::getShuffleExtract(
359     ExtractElementInst *Ext0, ExtractElementInst *Ext1,
360     unsigned PreferredExtractIndex = InvalidIndex) const {
361   auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
362   auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
363   assert(Index0C && Index1C && "Expected constant extract indexes");
364 
365   unsigned Index0 = Index0C->getZExtValue();
366   unsigned Index1 = Index1C->getZExtValue();
367 
368   // If the extract indexes are identical, no shuffle is needed.
369   if (Index0 == Index1)
370     return nullptr;
371 
372   Type *VecTy = Ext0->getVectorOperand()->getType();
373   assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
374   InstructionCost Cost0 =
375       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
376   InstructionCost Cost1 =
377       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
378 
379   // If both costs are invalid, no shuffle is needed.
380   if (!Cost0.isValid() && !Cost1.isValid())
381     return nullptr;
382 
383   // We are extracting from 2 different indexes, so one operand must be shuffled
384   // before performing a vector operation and/or extract. The more expensive
385   // extract will be replaced by a shuffle.
386   if (Cost0 > Cost1)
387     return Ext0;
388   if (Cost1 > Cost0)
389     return Ext1;
390 
391   // If the costs are equal and there is a preferred extract index, shuffle the
392   // opposite operand.
393   if (PreferredExtractIndex == Index0)
394     return Ext1;
395   if (PreferredExtractIndex == Index1)
396     return Ext0;
397 
398   // Otherwise, replace the extract with the higher index.
399   return Index0 > Index1 ? Ext0 : Ext1;
400 }
401 
402 /// Compare the relative costs of 2 extracts followed by scalar operation vs.
403 /// vector operation(s) followed by extract. Return true if the existing
404 /// instructions are cheaper than a vector alternative. Otherwise, return false
405 /// and if one of the extracts should be transformed to a shufflevector, set
406 /// \p ConvertToShuffle to that extract instruction.
407 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
408                                           ExtractElementInst *Ext1,
409                                           const Instruction &I,
410                                           ExtractElementInst *&ConvertToShuffle,
411                                           unsigned PreferredExtractIndex) {
412   auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
413   auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
414   assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
415 
416   unsigned Opcode = I.getOpcode();
417   Value *Ext0Src = Ext0->getVectorOperand();
418   Value *Ext1Src = Ext1->getVectorOperand();
419   Type *ScalarTy = Ext0->getType();
420   auto *VecTy = cast<VectorType>(Ext0Src->getType());
421   InstructionCost ScalarOpCost, VectorOpCost;
422 
423   // Get cost estimates for scalar and vector versions of the operation.
424   bool IsBinOp = Instruction::isBinaryOp(Opcode);
425   if (IsBinOp) {
426     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
427     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
428   } else {
429     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
430            "Expected a compare");
431     CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
432     ScalarOpCost = TTI.getCmpSelInstrCost(
433         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
434     VectorOpCost = TTI.getCmpSelInstrCost(
435         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
436   }
437 
438   // Get cost estimates for the extract elements. These costs will factor into
439   // both sequences.
440   unsigned Ext0Index = Ext0IndexC->getZExtValue();
441   unsigned Ext1Index = Ext1IndexC->getZExtValue();
442 
443   InstructionCost Extract0Cost =
444       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
445   InstructionCost Extract1Cost =
446       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);
447 
448   // A more expensive extract will always be replaced by a splat shuffle.
449   // For example, if Ext0 is more expensive:
450   // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
451   // extelt (opcode (splat V0, Ext0), V1), Ext1
452   // TODO: Evaluate whether that always results in lowest cost. Alternatively,
453   //       check the cost of creating a broadcast shuffle and shuffling both
454   //       operands to element 0.
455   unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
456   unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
457   InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
458 
459   // Extra uses of the extracts mean that we include those costs in the
460   // vector total because those instructions will not be eliminated.
461   InstructionCost OldCost, NewCost;
462   if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
463     // Handle a special case. If the 2 extracts are identical, adjust the
464     // formulas to account for that. The extra use charge allows for either the
465     // CSE'd pattern or an unoptimized form with identical values:
466     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
467     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
468                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
469     OldCost = CheapExtractCost + ScalarOpCost;
470     NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
471   } else {
472     // Handle the general case. Each extract is actually a different value:
473     // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
474     OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
475     NewCost = VectorOpCost + CheapExtractCost +
476               !Ext0->hasOneUse() * Extract0Cost +
477               !Ext1->hasOneUse() * Extract1Cost;
478   }
479 
480   ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
481   if (ConvertToShuffle) {
482     if (IsBinOp && DisableBinopExtractShuffle)
483       return true;
484 
485     // If we are extracting from 2 different indexes, then one operand must be
486     // shuffled before performing the vector operation. The shuffle mask is
487     // poison except for 1 lane that is being translated to the remaining
488     // extraction lane. Therefore, it is a splat shuffle. Ex:
489     // ShufMask = { poison, poison, 0, poison }
490     // TODO: The cost model has an option for a "broadcast" shuffle
491     //       (splat-from-element-0), but no option for a more general splat.
492     if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
493       SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
494                                    PoisonMaskElem);
495       ShuffleMask[BestInsIndex] = BestExtIndex;
496       NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
497                                     VecTy, ShuffleMask, CostKind, 0, nullptr,
498                                     {ConvertToShuffle});
499     } else {
500       NewCost +=
501           TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy,
502                              {}, CostKind, 0, nullptr, {ConvertToShuffle});
503     }
504   }
505 
506   // Aggressively form a vector op if the cost is equal because the transform
507   // may enable further optimization.
508   // Codegen can reverse this transform (scalarize) if it was not profitable.
509   return OldCost < NewCost;
510 }
511 
512 /// Create a shuffle that translates (shifts) 1 element from the input vector
513 /// to a new element location.
514 static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
515                                  unsigned NewIndex, IRBuilder<> &Builder) {
516   // The shuffle mask is poison except for 1 lane that is being translated
517   // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
518   // ShufMask = { 2, poison, poison, poison }
519   auto *VecTy = cast<FixedVectorType>(Vec->getType());
520   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
521   ShufMask[NewIndex] = OldIndex;
522   return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
523 }
524 
525 /// Given an extract element instruction with constant index operand, shuffle
526 /// the source vector (shift the scalar element) to a NewIndex for extraction.
527 /// Return null if the input can be constant folded, so that we are not creating
528 /// unnecessary instructions.
529 static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
530                                             unsigned NewIndex,
531                                             IRBuilder<> &Builder) {
532   // Shufflevectors can only be created for fixed-width vectors.
533   Value *X = ExtElt->getVectorOperand();
534   if (!isa<FixedVectorType>(X->getType()))
535     return nullptr;
536 
537   // If the extract can be constant-folded, the IR has not been simplified yet;
538   // defer to other passes to handle that.
539   Value *C = ExtElt->getIndexOperand();
540   assert(isa<ConstantInt>(C) && "Expected a constant index operand");
541   if (isa<Constant>(X))
542     return nullptr;
543 
544   Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
545                                    NewIndex, Builder);
546   return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
547 }
548 
549 /// Try to reduce extract element costs by converting scalar compares to vector
550 /// compares followed by extract.
551 /// cmp (ext0 V0, C), (ext1 V1, C)
552 void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
553                                   ExtractElementInst *Ext1, Instruction &I) {
554   assert(isa<CmpInst>(&I) && "Expected a compare");
555   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
556              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
557          "Expected matching constant extract indexes");
558 
559   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
560   ++NumVecCmp;
561   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
562   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
563   Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
564   Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
565   replaceValue(I, *NewExt);
566 }
567 
568 /// Try to reduce extract element costs by converting scalar binops to vector
569 /// binops followed by extract.
570 /// bo (ext0 V0, C), (ext1 V1, C)
571 void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
572                                     ExtractElementInst *Ext1, Instruction &I) {
573   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
574   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
575              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
576          "Expected matching constant extract indexes");
577 
578   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
579   ++NumVecBO;
580   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
581   Value *VecBO =
582       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
583 
584   // All IR flags are safe to back-propagate because any potential poison
585   // created in unused vector elements is discarded by the extract.
586   if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
587     VecBOInst->copyIRFlags(&I);
588 
589   Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
590   replaceValue(I, *NewExt);
591 }
592 
593 /// Match an instruction with extracted vector operands.
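/// Illustrative IR sketch for the profitable case with matching indexes:
///   %e0 = extractelement <4 x i32> %x, i64 1
///   %e1 = extractelement <4 x i32> %y, i64 1
///   %r  = add i32 %e0, %e1
/// -->
///   %v = add <4 x i32> %x, %y
///   %r = extractelement <4 x i32> %v, i64 1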
594 bool VectorCombine::foldExtractExtract(Instruction &I) {
595   // It is not safe to transform things like div, urem, etc. because we may
596   // create undefined behavior when executing those on unknown vector elements.
597   if (!isSafeToSpeculativelyExecute(&I))
598     return false;
599 
600   Instruction *I0, *I1;
601   CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
602   if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
603       !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
604     return false;
605 
606   Value *V0, *V1;
607   uint64_t C0, C1;
608   if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
609       !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
610       V0->getType() != V1->getType())
611     return false;
612 
613   // If the scalar value 'I' is going to be re-inserted into a vector, then try
614   // to create an extract to that same element. The extract/insert can be
615   // reduced to a "select shuffle".
616   // TODO: If we add a larger pattern match that starts from an insert, this
617   //       probably becomes unnecessary.
618   auto *Ext0 = cast<ExtractElementInst>(I0);
619   auto *Ext1 = cast<ExtractElementInst>(I1);
620   uint64_t InsertIndex = InvalidIndex;
621   if (I.hasOneUse())
622     match(I.user_back(),
623           m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
624 
625   ExtractElementInst *ExtractToChange;
626   if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
627     return false;
628 
629   if (ExtractToChange) {
630     unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
631     ExtractElementInst *NewExtract =
632         translateExtract(ExtractToChange, CheapExtractIdx, Builder);
633     if (!NewExtract)
634       return false;
635     if (ExtractToChange == Ext0)
636       Ext0 = NewExtract;
637     else
638       Ext1 = NewExtract;
639   }
640 
641   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
642     foldExtExtCmp(Ext0, Ext1, I);
643   else
644     foldExtExtBinop(Ext0, Ext1, I);
645 
646   Worklist.push(Ext0);
647   Worklist.push(Ext1);
648   return true;
649 }
650 
651 /// Try to replace an extract + scalar fneg + insert with a vector fneg +
652 /// shuffle.
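/// Illustrative IR sketch (same-length source and destination vectors):
///   %e = extractelement <4 x float> %src, i64 1
///   %n = fneg float %e
///   %r = insertelement <4 x float> %dst, float %n, i64 1
/// -->
///   %vn = fneg <4 x float> %src
///   %r  = shufflevector <4 x float> %dst, <4 x float> %vn,
///                       <4 x i32> <i32 0, i32 5, i32 2, i32 3>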
653 bool VectorCombine::foldInsExtFNeg(Instruction &I) {
654   // Match an insert (op (extract)) pattern.
655   Value *DestVec;
656   uint64_t Index;
657   Instruction *FNeg;
658   if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
659                              m_ConstantInt(Index))))
660     return false;
661 
662   // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
663   Value *SrcVec;
664   Instruction *Extract;
665   if (!match(FNeg, m_FNeg(m_CombineAnd(
666                        m_Instruction(Extract),
667                        m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
668     return false;
669 
670   auto *VecTy = cast<FixedVectorType>(I.getType());
671   auto *ScalarTy = VecTy->getScalarType();
672   auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
673   if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType())
674     return false;
675 
676   // Ignore bogus insert/extract index.
677   unsigned NumElts = VecTy->getNumElements();
678   if (Index >= NumElts)
679     return false;
680 
681   // We are inserting the negated element into the same lane that we extracted
682   // from. This is equivalent to a select-shuffle that chooses all but the
683   // negated element from the destination vector.
684   SmallVector<int> Mask(NumElts);
685   std::iota(Mask.begin(), Mask.end(), 0);
686   Mask[Index] = Index + NumElts;
687   InstructionCost OldCost =
688       TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) +
689       TTI.getVectorInstrCost(I, VecTy, CostKind, Index);
690 
691   // If the extract has one use, it will be eliminated, so count it in the
692   // original cost. If it has more than one use, ignore the cost because it will
693   // be the same before/after.
694   if (Extract->hasOneUse())
695     OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);
696 
697   InstructionCost NewCost =
698       TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) +
699       TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, CostKind);
700 
701   bool NeedLenChg = SrcVecTy->getNumElements() != NumElts;
702   // If the source and destination vectors have different lengths, we also need
703   // a length-changing shuffle; add its cost.
704   SmallVector<int> SrcMask;
705   if (NeedLenChg) {
706     SrcMask.assign(NumElts, PoisonMaskElem);
707     SrcMask[Index] = Index;
708     NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
709                                   SrcVecTy, SrcMask, CostKind);
710   }
711 
712   if (NewCost > OldCost)
713     return false;
714 
715   Value *NewShuf;
716   // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index
717   Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
718   if (NeedLenChg) {
719     // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
720     Value *LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
721     NewShuf = Builder.CreateShuffleVector(DestVec, LenChgShuf, Mask);
722   } else {
723     // shuffle DestVec, (fneg SrcVec), Mask
724     NewShuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
725   }
726 
727   replaceValue(I, *NewShuf);
728   return true;
729 }
730 
731 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
732 /// destination type followed by shuffle. This can enable further transforms by
733 /// moving bitcasts or shuffles together.
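/// Illustrative IR sketch for a unary shuffle with wide-to-narrow elements:
///   %s = shufflevector <4 x i32> %v, <4 x i32> poison,
///                      <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %r = bitcast <4 x i32> %s to <8 x i16>
/// -->
///   %c = bitcast <4 x i32> %v to <8 x i16>
///   %r = shufflevector <8 x i16> %c, <8 x i16> poison,
///                      <8 x i32> <i32 6, i32 7, i32 4, i32 5, i32 2, i32 3, i32 0, i32 1>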
734 bool VectorCombine::foldBitcastShuffle(Instruction &I) {
735   Value *V0, *V1;
736   ArrayRef<int> Mask;
737   if (!match(&I, m_BitCast(m_OneUse(
738                      m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
739     return false;
740 
741   // 1) Do not fold a bitcast of a shuffle for scalable types. First, the shuffle
742   // cost for scalable types is unknown; second, we cannot determine whether the
743   // narrowed shuffle mask for a scalable type is a splat.
744   // 2) Disallow non-vector casts.
745   // TODO: We could allow any shuffle.
746   auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
747   auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
748   if (!DestTy || !SrcTy)
749     return false;
750 
751   unsigned DestEltSize = DestTy->getScalarSizeInBits();
752   unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
753   if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
754     return false;
755 
756   bool IsUnary = isa<UndefValue>(V1);
757 
758   // For binary shuffles, only fold bitcast(shuffle(X,Y))
759   // if it won't increase the number of bitcasts.
760   if (!IsUnary) {
761     auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType());
762     auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType());
763     if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
764         !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
765       return false;
766   }
767 
768   SmallVector<int, 16> NewMask;
769   if (DestEltSize <= SrcEltSize) {
770     // The bitcast is from wide to narrow/equal elements. The shuffle mask can
771     // always be expanded to the equivalent form choosing narrower elements.
772     assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
773     unsigned ScaleFactor = SrcEltSize / DestEltSize;
774     narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
775   } else {
776     // The bitcast is from narrow elements to wide elements. The shuffle mask
777     // must choose consecutive elements to allow casting first.
778     assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
779     unsigned ScaleFactor = DestEltSize / SrcEltSize;
780     if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
781       return false;
782   }
783 
784   // Bitcast the shuffle source - keep its original width but use the
785   // destination scalar type.
786   unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
787   auto *NewShuffleTy =
788       FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
789   auto *OldShuffleTy =
790       FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
791   unsigned NumOps = IsUnary ? 1 : 2;
792 
793   // The new shuffle must not cost more than the old shuffle.
794   TargetTransformInfo::ShuffleKind SK =
795       IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
796               : TargetTransformInfo::SK_PermuteTwoSrc;
797 
798   InstructionCost NewCost =
799       TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CostKind) +
800       (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
801                                      TargetTransformInfo::CastContextHint::None,
802                                      CostKind));
803   InstructionCost OldCost =
804       TTI.getShuffleCost(SK, SrcTy, Mask, CostKind) +
805       TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
806                            TargetTransformInfo::CastContextHint::None,
807                            CostKind);
808 
809   LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n  OldCost: "
810                     << OldCost << " vs NewCost: " << NewCost << "\n");
811 
812   if (NewCost > OldCost || !NewCost.isValid())
813     return false;
814 
815   // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
816   ++NumShufOfBitcast;
817   Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
818   Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
819   Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
820   replaceValue(I, *Shuf);
821   return true;
822 }
823 
824 /// VP Intrinsics whose vector operands are both splat values may be simplified
825 /// into the scalar version of the operation and the result splatted. This
826 /// can lead to scalarization down the line.
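/// Illustrative IR sketch (all-true mask, operands splatted from %x and %y):
///   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %xsplat, <4 x i32> %ysplat,
///                                          <4 x i1> %alltrue, i32 %evl)
/// -->
///   %s   = add i32 %x, %y
///   %ins = insertelement <4 x i32> poison, i32 %s, i64 0
///   %r   = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer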
827 bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
828   if (!isa<VPIntrinsic>(I))
829     return false;
830   VPIntrinsic &VPI = cast<VPIntrinsic>(I);
831   Value *Op0 = VPI.getArgOperand(0);
832   Value *Op1 = VPI.getArgOperand(1);
833 
834   if (!isSplatValue(Op0) || !isSplatValue(Op1))
835     return false;
836 
837   // Check getSplatValue early in this function, to avoid doing unnecessary
838   // work.
839   Value *ScalarOp0 = getSplatValue(Op0);
840   Value *ScalarOp1 = getSplatValue(Op1);
841   if (!ScalarOp0 || !ScalarOp1)
842     return false;
843 
844   // For the binary VP intrinsics supported here, the result on disabled lanes
845   // is a poison value. For now, only do this simplification if all lanes
846   // are active.
847   // TODO: Relax the condition that all lanes are active by using insertelement
848   // on inactive lanes.
849   auto IsAllTrueMask = [](Value *MaskVal) {
850     if (Value *SplattedVal = getSplatValue(MaskVal))
851       if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
852         return ConstValue->isAllOnesValue();
853     return false;
854   };
855   if (!IsAllTrueMask(VPI.getArgOperand(2)))
856     return false;
857 
858   // Check to make sure we support scalarization of the intrinsic
859   Intrinsic::ID IntrID = VPI.getIntrinsicID();
860   if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
861     return false;
862 
863   // Calculate cost of splatting both operands into vectors and the vector
864   // intrinsic
865   VectorType *VecTy = cast<VectorType>(VPI.getType());
866   SmallVector<int> Mask;
867   if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
868     Mask.resize(FVTy->getNumElements(), 0);
869   InstructionCost SplatCost =
870       TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
871       TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, Mask,
872                          CostKind);
873 
874   // Calculate the cost of the VP Intrinsic
875   SmallVector<Type *, 4> Args;
876   for (Value *V : VPI.args())
877     Args.push_back(V->getType());
878   IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
879   InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
880   InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
881 
882   // Determine scalar opcode
883   std::optional<unsigned> FunctionalOpcode =
884       VPI.getFunctionalOpcode();
885   std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
886   if (!FunctionalOpcode) {
887     ScalarIntrID = VPI.getFunctionalIntrinsicID();
888     if (!ScalarIntrID)
889       return false;
890   }
891 
892   // Calculate cost of scalarizing
893   InstructionCost ScalarOpCost = 0;
894   if (ScalarIntrID) {
895     IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
896     ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
897   } else {
898     ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
899                                               VecTy->getScalarType(), CostKind);
900   }
901 
902   // The existing splats may be kept around if other instructions use them.
903   InstructionCost CostToKeepSplats =
904       (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
905   InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
906 
907   LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
908                     << "\n");
909   LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
910                     << ", Cost of scalarizing: " << NewCost << "\n");
911 
912   // We want to scalarize unless the vector variant actually has lower cost.
913   if (OldCost < NewCost || !NewCost.isValid())
914     return false;
915 
916   // Scalarize the intrinsic
917   ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
918   Value *EVL = VPI.getArgOperand(3);
919 
920   // If the VP op might introduce UB or poison, we can scalarize it provided
921   // that we know the EVL > 0: If the EVL is zero, then the original VP op
922   // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
923   // scalarizing it.
924   bool SafeToSpeculate;
925   if (ScalarIntrID)
926     SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID)
927                           .hasFnAttr(Attribute::AttrKind::Speculatable);
928   else
929     SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
930         *FunctionalOpcode, &VPI, nullptr, &AC, &DT);
931   if (!SafeToSpeculate &&
932       !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI)))
933     return false;
934 
935   Value *ScalarVal =
936       ScalarIntrID
937           ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
938                                     {ScalarOp0, ScalarOp1})
939           : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
940                                 ScalarOp0, ScalarOp1);
941 
942   replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
943   return true;
944 }
945 
946 /// Match a vector binop or compare instruction with at least one inserted
947 /// scalar operand and convert to scalar binop/cmp followed by insertelement.
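/// Illustrative IR sketch with one inserted scalar and one constant operand:
///   %i = insertelement <4 x i32> <i32 0, i32 1, i32 2, i32 3>, i32 %x, i64 0
///   %r = add <4 x i32> %i, <i32 7, i32 7, i32 7, i32 7>
/// -->
///   %s = add i32 %x, 7
///   %r = insertelement <4 x i32> <i32 7, i32 8, i32 9, i32 10>, i32 %s, i64 0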
948 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
949   CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
950   Value *Ins0, *Ins1;
951   if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
952       !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
953     return false;
954 
955   // Do not convert the vector condition of a vector select into a scalar
956   // condition. That may cause problems for codegen because of differences in
957   // boolean formats and register-file transfers.
958   // TODO: Can we account for that in the cost model?
959   bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
960   if (IsCmp)
961     for (User *U : I.users())
962       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
963         return false;
964 
965   // Match against one or both scalar values being inserted into constant
966   // vectors:
967   // vec_op VecC0, (inselt VecC1, V1, Index)
968   // vec_op (inselt VecC0, V0, Index), VecC1
969   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
970   // TODO: Deal with mismatched index constants and variable indexes?
971   Constant *VecC0 = nullptr, *VecC1 = nullptr;
972   Value *V0 = nullptr, *V1 = nullptr;
973   uint64_t Index0 = 0, Index1 = 0;
974   if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
975                                m_ConstantInt(Index0))) &&
976       !match(Ins0, m_Constant(VecC0)))
977     return false;
978   if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
979                                m_ConstantInt(Index1))) &&
980       !match(Ins1, m_Constant(VecC1)))
981     return false;
982 
983   bool IsConst0 = !V0;
984   bool IsConst1 = !V1;
985   if (IsConst0 && IsConst1)
986     return false;
987   if (!IsConst0 && !IsConst1 && Index0 != Index1)
988     return false;
989 
990   auto *VecTy0 = cast<VectorType>(Ins0->getType());
991   auto *VecTy1 = cast<VectorType>(Ins1->getType());
992   if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
993       VecTy1->getElementCount().getKnownMinValue() <= Index1)
994     return false;
995 
996   // Bail for a single insertion if the inserted scalar is a load.
997   // TODO: Handle this once getVectorInstrCost can model load/store costs.
998   auto *I0 = dyn_cast_or_null<Instruction>(V0);
999   auto *I1 = dyn_cast_or_null<Instruction>(V1);
1000   if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
1001       (IsConst1 && I0 && I0->mayReadFromMemory()))
1002     return false;
1003 
1004   uint64_t Index = IsConst0 ? Index1 : Index0;
1005   Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
1006   Type *VecTy = I.getType();
1007   assert(VecTy->isVectorTy() &&
1008          (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
1009          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1010           ScalarTy->isPointerTy()) &&
1011          "Unexpected types for insert element into binop or cmp");
1012 
1013   unsigned Opcode = I.getOpcode();
1014   InstructionCost ScalarOpCost, VectorOpCost;
1015   if (IsCmp) {
1016     CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
1017     ScalarOpCost = TTI.getCmpSelInstrCost(
1018         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
1019     VectorOpCost = TTI.getCmpSelInstrCost(
1020         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1021   } else {
1022     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
1023     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
1024   }
1025 
1026   // Get cost estimate for the insert element. This cost will factor into
1027   // both sequences.
1028   InstructionCost InsertCost = TTI.getVectorInstrCost(
1029       Instruction::InsertElement, VecTy, CostKind, Index);
1030   InstructionCost OldCost =
1031       (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
1032   InstructionCost NewCost = ScalarOpCost + InsertCost +
1033                             (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
1034                             (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
1035 
1036   // We want to scalarize unless the vector variant actually has lower cost.
1037   if (OldCost < NewCost || !NewCost.isValid())
1038     return false;
1039 
1040   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1041   // inselt NewVecC, (scalar_op V0, V1), Index
1042   if (IsCmp)
1043     ++NumScalarCmp;
1044   else
1045     ++NumScalarBO;
1046 
1047   // For constant cases, extract the scalar element; this should constant fold.
1048   if (IsConst0)
1049     V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
1050   if (IsConst1)
1051     V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
1052 
1053   Value *Scalar =
1054       IsCmp ? Builder.CreateCmp(Pred, V0, V1)
1055             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
1056 
1057   Scalar->setName(I.getName() + ".scalar");
1058 
1059   // All IR flags are safe to back-propagate. There is no potential for extra
1060   // poison to be created by the scalar instruction.
1061   if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
1062     ScalarInst->copyIRFlags(&I);
1063 
1064   // Fold the vector constants in the original vectors into a new base vector.
1065   Value *NewVecC =
1066       IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
1067             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
1068   Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
1069   replaceValue(I, *Insert);
1070   return true;
1071 }
1072 
1073 /// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1074 /// a vector into vector operations followed by extract. Note: The SLP pass
1075 /// may miss this pattern because of implementation problems.
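/// Illustrative IR sketch (the lane that is kept depends on the cost model):
///   %e0 = extractelement <4 x i32> %x, i64 0
///   %e1 = extractelement <4 x i32> %x, i64 3
///   %c0 = icmp sgt i32 %e0, 42
///   %c1 = icmp sgt i32 %e1, 7
///   %r  = and i1 %c0, %c1
/// -->
///   %vc = icmp sgt <4 x i32> %x, <i32 42, i32 poison, i32 poison, i32 7>
///   %sh = shufflevector <4 x i1> %vc, <4 x i1> poison,
///                       <4 x i32> <i32 3, i32 poison, i32 poison, i32 poison>
///   %vr = and <4 x i1> %vc, %sh
///   %r  = extractelement <4 x i1> %vr, i64 0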
1076 bool VectorCombine::foldExtractedCmps(Instruction &I) {
1077   auto *BI = dyn_cast<BinaryOperator>(&I);
1078 
1079   // We are looking for a scalar binop of booleans.
1080   // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1081   if (!BI || !I.getType()->isIntegerTy(1))
1082     return false;
1083 
1084   // The compare predicates should match, and each compare should have a
1085   // constant operand.
1086   Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
1087   Instruction *I0, *I1;
1088   Constant *C0, *C1;
1089   CmpPredicate P0, P1;
1090   // FIXME: Use CmpPredicate::getMatching here.
1091   if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
1092       !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) ||
1093       P0 != static_cast<CmpInst::Predicate>(P1))
1094     return false;
1095 
1096   // The compare operands must be extracts of the same vector with constant
1097   // extract indexes.
1098   Value *X;
1099   uint64_t Index0, Index1;
1100   if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) ||
1101       !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))
1102     return false;
1103 
1104   auto *Ext0 = cast<ExtractElementInst>(I0);
1105   auto *Ext1 = cast<ExtractElementInst>(I1);
1106   ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, InvalidIndex);
1107   if (!ConvertToShuf)
1108     return false;
1109   assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1110          "Unknown ExtractElementInst");
1111 
1112   // The original scalar pattern is:
1113   // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1114   CmpInst::Predicate Pred = P0;
1115   unsigned CmpOpcode =
1116       CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp;
1117   auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
1118   if (!VecTy)
1119     return false;
1120 
1121   InstructionCost Ext0Cost =
1122       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1123   InstructionCost Ext1Cost =
1124       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1125   InstructionCost CmpCost = TTI.getCmpSelInstrCost(
1126       CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred,
1127       CostKind);
1128 
1129   InstructionCost OldCost =
1130       Ext0Cost + Ext1Cost + CmpCost * 2 +
1131       TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);
1132 
1133   // The proposed vector pattern is:
1134   // vcmp = cmp Pred X, VecC
1135   // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1136   int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1137   int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1138   auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
1139   InstructionCost NewCost = TTI.getCmpSelInstrCost(
1140       CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred,
1141       CostKind);
1142   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1143   ShufMask[CheapIndex] = ExpensiveIndex;
1144   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
1145                                 ShufMask, CostKind);
1146   NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
1147   NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
1148   NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1149   NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1150 
1151   // Aggressively form vector ops if the cost is equal because the transform
1152   // may enable further optimization.
1153   // Codegen can reverse this transform (scalarize) if it was not profitable.
1154   if (OldCost < NewCost || !NewCost.isValid())
1155     return false;
1156 
1157   // Create a vector constant from the 2 scalar constants.
1158   SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1159                                    PoisonValue::get(VecTy->getElementType()));
1160   CmpC[Index0] = C0;
1161   CmpC[Index1] = C1;
1162   Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
1163   Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
1164   Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1165   Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1166   Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
1167   Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
1168   replaceValue(I, *NewExt);
1169   ++NumVecCmpBO;
1170   return true;
1171 }
1172 
1173 // Check if the memory location is modified between two instructions in the same BB.
1174 static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1175                                  BasicBlock::iterator End,
1176                                  const MemoryLocation &Loc, AAResults &AA) {
1177   unsigned NumScanned = 0;
1178   return std::any_of(Begin, End, [&](const Instruction &Instr) {
1179     return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1180            ++NumScanned > MaxInstrsToScan;
1181   });
1182 }
1183 
1184 namespace {
1185 /// Helper class to indicate whether a vector index can be safely scalarized and
1186 /// whether a freeze needs to be inserted.
1187 class ScalarizationResult {
1188   enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1189 
1190   StatusTy Status;
1191   Value *ToFreeze;
1192 
1193   ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1194       : Status(Status), ToFreeze(ToFreeze) {}
1195 
1196 public:
1197   ScalarizationResult(const ScalarizationResult &Other) = default;
1198   ~ScalarizationResult() {
1199     assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1200   }
1201 
1202   static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1203   static ScalarizationResult safe() { return {StatusTy::Safe}; }
1204   static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1205     return {StatusTy::SafeWithFreeze, ToFreeze};
1206   }
1207 
1208   /// Returns true if the index can be scalarized without requiring a freeze.
1209   bool isSafe() const { return Status == StatusTy::Safe; }
1210   /// Returns true if the index cannot be scalarized.
1211   bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1212   /// Returns true if the index can be scalarized, but requires inserting a
1213   /// freeze.
1214   bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1215 
1216   /// Reset the state to Unsafe and clear ToFreeze if set.
1217   void discard() {
1218     ToFreeze = nullptr;
1219     Status = StatusTy::Unsafe;
1220   }
1221 
1222   /// Freeze ToFreeze and update the use in \p UserI to use the frozen value.
1223   void freeze(IRBuilder<> &Builder, Instruction &UserI) {
1224     assert(isSafeWithFreeze() &&
1225            "should only be used when freezing is required");
1226     assert(is_contained(ToFreeze->users(), &UserI) &&
1227            "UserI must be a user of ToFreeze");
1228     IRBuilder<>::InsertPointGuard Guard(Builder);
1229     Builder.SetInsertPoint(cast<Instruction>(&UserI));
1230     Value *Frozen =
1231         Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
1232     for (Use &U : make_early_inc_range((UserI.operands())))
1233       if (U.get() == ToFreeze)
1234         U.set(Frozen);
1235 
1236     ToFreeze = nullptr;
1237   }
1238 };
1239 } // namespace
1240 
1241 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1242 /// Idx. \p Idx must access a valid vector element.
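/// A small illustrative example (hypothetical values): for a <4 x i32> vector
/// with
///   %idx = and i64 %i, 3
/// the masked range [0, 4) is always in bounds, but %i itself may be poison,
/// so the result is safeWithFreeze(%i); a caller that commits to the transform
/// then freezes %i and rewrites the 'and' to use the frozen value.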
1243 static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1244                                               Instruction *CtxI,
1245                                               AssumptionCache &AC,
1246                                               const DominatorTree &DT) {
1247   // We do checks for both fixed vector types and scalable vector types.
1248   // This is the number of elements of fixed vector types,
1249   // or the minimum number of elements of scalable vector types.
1250   uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1251 
1252   if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1253     if (C->getValue().ult(NumElements))
1254       return ScalarizationResult::safe();
1255     return ScalarizationResult::unsafe();
1256   }
1257 
1258   unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1259   APInt Zero(IntWidth, 0);
1260   APInt MaxElts(IntWidth, NumElements);
1261   ConstantRange ValidIndices(Zero, MaxElts);
1262   ConstantRange IdxRange(IntWidth, true);
1263 
1264   if (isGuaranteedNotToBePoison(Idx, &AC)) {
1265     if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
1266                                                    true, &AC, CtxI, &DT)))
1267       return ScalarizationResult::safe();
1268     return ScalarizationResult::unsafe();
1269   }
1270 
1271   // If the index may be poison, check if we can insert a freeze before the
1272   // range of the index is restricted.
1273   Value *IdxBase;
1274   ConstantInt *CI;
1275   if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1276     IdxRange = IdxRange.binaryAnd(CI->getValue());
1277   } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1278     IdxRange = IdxRange.urem(CI->getValue());
1279   }
1280 
1281   if (ValidIndices.contains(IdxRange))
1282     return ScalarizationResult::safeWithFreeze(IdxBase);
1283   return ScalarizationResult::unsafe();
1284 }
1285 
1286 /// The memory operation on a vector of \p ScalarType had alignment of
1287 /// \p VectorAlignment. Compute the maximal, but conservatively correct,
1288 /// alignment that will be valid for the memory operation on a single scalar
1289 /// element of the same type with index \p Idx.
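/// As a worked example (values chosen for illustration only): a <4 x i32>
/// access with alignment 16, scalarized at constant index 2, touches byte
/// offset 2 * 4 = 8, so the scalar access can use commonAlignment(16, 8) = 8.
/// For a non-constant index only commonAlignment(16, 4) = 4 can be assumed,
/// since any multiple of the element store size might be accessed.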
1290 static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1291                                                 Type *ScalarType, Value *Idx,
1292                                                 const DataLayout &DL) {
1293   if (auto *C = dyn_cast<ConstantInt>(Idx))
1294     return commonAlignment(VectorAlignment,
1295                            C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1296   return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1297 }
1298 
1299 // Combine patterns like:
1300 //   %0 = load <4 x i32>, <4 x i32>* %a
1301 //   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1302 //   store <4 x i32> %1, <4 x i32>* %a
1303 // to:
1304 //   %0 = bitcast <4 x i32>* %a to i32*
1305 //   %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
1306 //   store i32 %b, i32* %1
1307 bool VectorCombine::foldSingleElementStore(Instruction &I) {
1308   auto *SI = cast<StoreInst>(&I);
1309   if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1310     return false;
1311 
1312   // TODO: Combine more complicated patterns (multiple insert) by referencing
1313   // TargetTransformInfo.
1314   Instruction *Source;
1315   Value *NewElement;
1316   Value *Idx;
1317   if (!match(SI->getValueOperand(),
1318              m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1319                          m_Value(Idx))))
1320     return false;
1321 
1322   if (auto *Load = dyn_cast<LoadInst>(Source)) {
1323     auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1324     Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1325     // Don't optimize for atomic/volatile load or store. Ensure memory is not
1326     // modified in between, vector type matches store size, and index is inbounds.
1327     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1328         !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1329         SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1330       return false;
1331 
1332     auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
1333     if (ScalarizableIdx.isUnsafe() ||
1334         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1335                              MemoryLocation::get(SI), AA))
1336       return false;
1337 
1338     if (ScalarizableIdx.isSafeWithFreeze())
1339       ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1340     Value *GEP = Builder.CreateInBoundsGEP(
1341         SI->getValueOperand()->getType(), SI->getPointerOperand(),
1342         {ConstantInt::get(Idx->getType(), 0), Idx});
1343     StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
1344     NSI->copyMetadata(*SI);
1345     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1346         std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
1347         *DL);
1348     NSI->setAlignment(ScalarOpAlignment);
1349     replaceValue(I, *NSI);
1350     eraseInstruction(I);
1351     return true;
1352   }
1353 
1354   return false;
1355 }
1356 
1357 /// Try to scalarize vector loads feeding extractelement instructions.
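/// A minimal sketch of the pattern (illustrative IR, names made up):
///   %v = load <4 x i32>, ptr %p
///   %e0 = extractelement <4 x i32> %v, i64 0
///   %e1 = extractelement <4 x i32> %v, i64 2
/// may become, when the cost model favors the scalar loads:
///   %gep0 = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 0
///   %e0.scalar = load i32, ptr %gep0
///   %gep1 = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 2
///   %e1.scalar = load i32, ptr %gep1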
1358 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1359   Value *Ptr;
1360   if (!match(&I, m_Load(m_Value(Ptr))))
1361     return false;
1362 
1363   auto *VecTy = cast<VectorType>(I.getType());
1364   auto *LI = cast<LoadInst>(&I);
1365   if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
1366     return false;
1367 
1368   InstructionCost OriginalCost =
1369       TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
1370                           LI->getPointerAddressSpace(), CostKind);
1371   InstructionCost ScalarizedCost = 0;
1372 
1373   Instruction *LastCheckedInst = LI;
1374   unsigned NumInstChecked = 0;
1375   DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1376   auto FailureGuard = make_scope_exit([&]() {
1377     // If the transform is aborted, discard the ScalarizationResults.
1378     for (auto &Pair : NeedFreeze)
1379       Pair.second.discard();
1380   });
1381 
1382   // Check if all users of the load are extracts with no memory modifications
1383   // between the load and the extract. Compute the cost of both the original
1384   // code and the scalarized version.
1385   for (User *U : LI->users()) {
1386     auto *UI = dyn_cast<ExtractElementInst>(U);
1387     if (!UI || UI->getParent() != LI->getParent())
1388       return false;
1389 
1390     // Check if any instruction between the load and the extract may modify
1391     // memory.
1392     if (LastCheckedInst->comesBefore(UI)) {
1393       for (Instruction &I :
1394            make_range(std::next(LI->getIterator()), UI->getIterator())) {
1395         // Bail out if we reached the check limit or the instruction may write
1396         // to memory.
1397         if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
1398           return false;
1399         NumInstChecked++;
1400       }
1401       LastCheckedInst = UI;
1402     }
1403 
1404     auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
1405     if (ScalarIdx.isUnsafe())
1406       return false;
1407     if (ScalarIdx.isSafeWithFreeze()) {
1408       NeedFreeze.try_emplace(UI, ScalarIdx);
1409       ScalarIdx.discard();
1410     }
1411 
1412     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
1413     OriginalCost +=
1414         TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
1415                                Index ? Index->getZExtValue() : -1);
1416     ScalarizedCost +=
1417         TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
1418                             Align(1), LI->getPointerAddressSpace(), CostKind);
1419     ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
1420   }
1421 
1422   if (ScalarizedCost >= OriginalCost)
1423     return false;
1424 
1425   // Replace extracts with narrow scalar loads.
1426   for (User *U : LI->users()) {
1427     auto *EI = cast<ExtractElementInst>(U);
1428     Value *Idx = EI->getOperand(1);
1429 
1430     // Insert 'freeze' for poison indexes.
1431     auto It = NeedFreeze.find(EI);
1432     if (It != NeedFreeze.end())
1433       It->second.freeze(Builder, *cast<Instruction>(Idx));
1434 
1435     Builder.SetInsertPoint(EI);
1436     Value *GEP =
1437         Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
1438     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
1439         VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
1440 
1441     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1442         LI->getAlign(), VecTy->getElementType(), Idx, *DL);
1443     NewLoad->setAlignment(ScalarOpAlignment);
1444 
1445     replaceValue(*EI, *NewLoad);
1446   }
1447 
1448   FailureGuard.release();
1449   return true;
1450 }
1451 
1452 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
1453 /// to "(bitcast (concat X, Y))"
1454 /// where X/Y are bitcasted from i1 mask vectors.
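/// A minimal illustrative example (hypothetical IR, names made up):
///   %x  = bitcast <4 x i1> %a to i4
///   %y  = bitcast <4 x i1> %b to i4
///   %zx = zext i4 %x to i8
///   %zy = zext i4 %y to i8
///   %sy = shl i8 %zy, 4
///   %r  = or disjoint i8 %zx, %sy
/// may become:
///   %concat = shufflevector <4 x i1> %a, <4 x i1> %b,
///       <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
///   %r = bitcast <8 x i1> %concat to i8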
1455 bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
1456   Type *Ty = I.getType();
1457   if (!Ty->isIntegerTy())
1458     return false;
1459 
1460   // TODO: Add big endian test coverage
1461   if (DL->isBigEndian())
1462     return false;
1463 
1464   // Restrict to disjoint cases so the mask vectors aren't overlapping.
1465   Instruction *X, *Y;
1466   if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
1467     return false;
1468 
1469   // Allow both sources to contain shl, to handle more generic pattern:
1470   // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
1471   Value *SrcX;
1472   uint64_t ShAmtX = 0;
1473   if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
1474       !match(X, m_OneUse(
1475                     m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
1476                           m_ConstantInt(ShAmtX)))))
1477     return false;
1478 
1479   Value *SrcY;
1480   uint64_t ShAmtY = 0;
1481   if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
1482       !match(Y, m_OneUse(
1483                     m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
1484                           m_ConstantInt(ShAmtY)))))
1485     return false;
1486 
1487   // Canonicalize larger shift to the RHS.
1488   if (ShAmtX > ShAmtY) {
1489     std::swap(X, Y);
1490     std::swap(SrcX, SrcY);
1491     std::swap(ShAmtX, ShAmtY);
1492   }
1493 
1494   // Ensure both sources are matching vXi1 bool mask types, and that the shift
1495   // difference is the mask width so they can be easily concatenated together.
1496   uint64_t ShAmtDiff = ShAmtY - ShAmtX;
1497   unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
1498   unsigned BitWidth = Ty->getPrimitiveSizeInBits();
1499   auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
1500   if (!MaskTy || SrcX->getType() != SrcY->getType() ||
1501       !MaskTy->getElementType()->isIntegerTy(1) ||
1502       MaskTy->getNumElements() != ShAmtDiff ||
1503       MaskTy->getNumElements() > (BitWidth / 2))
1504     return false;
1505 
1506   auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
1507   auto *ConcatIntTy =
1508       Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
1509   auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
1510 
1511   SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
1512   std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
1513 
1514   // TODO: Is it worth supporting multi use cases?
1515   InstructionCost OldCost = 0;
1516   OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
1517   OldCost +=
1518       NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1519   OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
1520                                       TTI::CastContextHint::None, CostKind);
1521   OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
1522                                       TTI::CastContextHint::None, CostKind);
1523 
1524   InstructionCost NewCost = 0;
1525   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
1526                                 ConcatMask, CostKind);
1527   NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
1528                                   TTI::CastContextHint::None, CostKind);
1529   if (Ty != ConcatIntTy)
1530     NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
1531                                     TTI::CastContextHint::None, CostKind);
1532   if (ShAmtX > 0)
1533     NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1534 
1535   if (NewCost > OldCost)
1536     return false;
1537 
1538   // Build bool mask concatenation, bitcast back to scalar integer, and perform
1539   // any residual zero-extension or shifting.
1540   Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
1541   Worklist.pushValue(Concat);
1542 
1543   Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
1544 
1545   if (Ty != ConcatIntTy) {
1546     Worklist.pushValue(Result);
1547     Result = Builder.CreateZExt(Result, Ty);
1548   }
1549 
1550   if (ShAmtX > 0) {
1551     Worklist.pushValue(Result);
1552     Result = Builder.CreateShl(Result, ShAmtX);
1553   }
1554 
1555   replaceValue(I, *Result);
1556   return true;
1557 }
1558 
1559 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1560 ///           -->  "binop (shuffle), (shuffle)".
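/// A minimal illustrative example (hypothetical IR, names made up): given
///   %s0 = shufflevector <4 x i32> %a, <4 x i32> %b,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %s1 = shufflevector <4 x i32> %c, <4 x i32> %d,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %bo = add <4 x i32> %s0, %s1
///   %r  = shufflevector <4 x i32> %bo, <4 x i32> poison,
///                       <4 x i32> <i32 3, i32 2, i32 1, i32 0>
/// the outer mask can be folded into the operand shuffles:
///   %t0 = shufflevector <4 x i32> %a, <4 x i32> %b,
///                       <4 x i32> <i32 5, i32 1, i32 4, i32 0>
///   %t1 = shufflevector <4 x i32> %c, <4 x i32> %d,
///                       <4 x i32> <i32 5, i32 1, i32 4, i32 0>
///   %r  = add <4 x i32> %t0, %t1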
1561 bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1562   BinaryOperator *BinOp;
1563   ArrayRef<int> OuterMask;
1564   if (!match(&I,
1565              m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask))))
1566     return false;
1567 
1568   // Don't introduce poison into div/rem.
1569   if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem))
1570     return false;
1571 
1572   Value *Op00, *Op01;
1573   ArrayRef<int> Mask0;
1574   if (!match(BinOp->getOperand(0),
1575              m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)))))
1576     return false;
1577 
1578   Value *Op10, *Op11;
1579   ArrayRef<int> Mask1;
1580   if (!match(BinOp->getOperand(1),
1581              m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)))))
1582     return false;
1583 
1584   Instruction::BinaryOps Opcode = BinOp->getOpcode();
1585   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1586   auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
1587   auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
1588   auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
1589   if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1590     return false;
1591 
1592   unsigned NumSrcElts = BinOpTy->getNumElements();
1593 
1594   // Don't accept shuffles that reference the second operand in
1595   // div/rem or if it's an undef arg.
1596   if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) &&
1597       any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
1598     return false;
1599 
1600   // Merge outer / inner shuffles.
1601   SmallVector<int> NewMask0, NewMask1;
1602   for (int M : OuterMask) {
1603     if (M < 0 || M >= (int)NumSrcElts) {
1604       NewMask0.push_back(PoisonMaskElem);
1605       NewMask1.push_back(PoisonMaskElem);
1606     } else {
1607       NewMask0.push_back(Mask0[M]);
1608       NewMask1.push_back(Mask1[M]);
1609     }
1610   }
1611 
1612   // Try to merge shuffles across the binop if the new shuffles are not costly.
1613   InstructionCost OldCost =
1614       TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) +
1615       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1616                          OuterMask, CostKind, 0, nullptr, {BinOp}, &I) +
1617       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1618                          CostKind, 0, nullptr, {Op00, Op01},
1619                          cast<Instruction>(BinOp->getOperand(0))) +
1620       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1621                          CostKind, 0, nullptr, {Op10, Op11},
1622                          cast<Instruction>(BinOp->getOperand(1)));
1623 
1624   InstructionCost NewCost =
1625       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1626                          CostKind, 0, nullptr, {Op00, Op01}) +
1627       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1628                          CostKind, 0, nullptr, {Op10, Op11}) +
1629       TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
1630 
1631   LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
1632                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1633                     << "\n");
1634 
1635   // If costs are equal, still fold as we reduce instruction count.
1636   if (NewCost > OldCost)
1637     return false;
1638 
1639   Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0);
1640   Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1);
1641   Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
1642 
1643   // Intersect flags from the old binops.
1644   if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1645     NewInst->copyIRFlags(BinOp);
1646 
1647   Worklist.pushValue(Shuf0);
1648   Worklist.pushValue(Shuf1);
1649   replaceValue(I, *NewBO);
1650   return true;
1651 }
1652 
1653 /// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
1654 /// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
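/// A minimal illustrative example (hypothetical IR, names made up):
///   %lhs = add <4 x i32> %x, %y
///   %rhs = add <4 x i32> %z, %w
///   %r = shufflevector <4 x i32> %lhs, <4 x i32> %rhs,
///                      <4 x i32> <i32 0, i32 5, i32 2, i32 7>
/// may become:
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> %z,
///                       <4 x i32> <i32 0, i32 5, i32 2, i32 7>
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> %w,
///                       <4 x i32> <i32 0, i32 5, i32 2, i32 7>
///   %r = add <4 x i32> %s0, %s1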
1655 bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
1656   ArrayRef<int> OldMask;
1657   Instruction *LHS, *RHS;
1658   if (!match(&I, m_Shuffle(m_OneUse(m_Instruction(LHS)),
1659                            m_OneUse(m_Instruction(RHS)), m_Mask(OldMask))))
1660     return false;
1661 
1662   // TODO: Add support for addlike etc.
1663   if (LHS->getOpcode() != RHS->getOpcode())
1664     return false;
1665 
1666   Value *X, *Y, *Z, *W;
1667   bool IsCommutative = false;
1668   CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
1669   if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
1670       match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
1671     auto *BO = cast<BinaryOperator>(LHS);
1672     // Don't introduce poison into div/rem.
1673     if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
1674       return false;
1675     IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
1676   } else if (match(LHS, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
1677              match(RHS, m_SpecificCmp(Pred, m_Value(Z), m_Value(W)))) {
1678     IsCommutative = cast<CmpInst>(LHS)->isCommutative();
1679   } else
1680     return false;
1681 
1682   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1683   auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType());
1684   auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType());
1685   if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
1686     return false;
1687 
1688   unsigned NumSrcElts = BinOpTy->getNumElements();
1689 
1690   // If we have something like "add X, Y" and "add Z, X", swap ops to match.
1691   if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
1692     std::swap(X, Y);
1693 
1694   auto ConvertToUnary = [NumSrcElts](int &M) {
1695     if (M >= (int)NumSrcElts)
1696       M -= NumSrcElts;
1697   };
1698 
1699   SmallVector<int> NewMask0(OldMask);
1700   TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
1701   if (X == Z) {
1702     llvm::for_each(NewMask0, ConvertToUnary);
1703     SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
1704     Z = PoisonValue::get(BinOpTy);
1705   }
1706 
1707   SmallVector<int> NewMask1(OldMask);
1708   TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
1709   if (Y == W) {
1710     llvm::for_each(NewMask1, ConvertToUnary);
1711     SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
1712     W = PoisonValue::get(BinOpTy);
1713   }
1714 
1715   // Try to replace a binop with a shuffle if the shuffle is not costly.
1716   InstructionCost OldCost =
1717       TTI.getInstructionCost(LHS, CostKind) +
1718       TTI.getInstructionCost(RHS, CostKind) +
1719       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinResTy,
1720                          OldMask, CostKind, 0, nullptr, {LHS, RHS}, &I);
1721 
1722   InstructionCost NewCost =
1723       TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) +
1724       TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W});
1725 
1726   if (Pred == CmpInst::BAD_ICMP_PREDICATE) {
1727     NewCost +=
1728         TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind);
1729   } else {
1730     auto *ShuffleCmpTy =
1731         FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
1732     NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy,
1733                                       ShuffleDstTy, Pred, CostKind);
1734   }
1735 
1736   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
1737                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1738                     << "\n");
1739 
1740   // If either shuffle will constant fold away, then fold for the same cost as
1741   // we will reduce the instruction count.
1742   bool ReducedInstCount = (isa<Constant>(X) && isa<Constant>(Z)) ||
1743                           (isa<Constant>(Y) && isa<Constant>(W));
1744   if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
1745     return false;
1746 
1747   Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
1748   Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1);
1749   Value *NewBO = Pred == CmpInst::BAD_ICMP_PREDICATE
1750                      ? Builder.CreateBinOp(
1751                            cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
1752                      : Builder.CreateCmp(Pred, Shuf0, Shuf1);
1753 
1754   // Intersect flags from the old binops.
1755   if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
1756     NewInst->copyIRFlags(LHS);
1757     NewInst->andIRFlags(RHS);
1758   }
1759 
1760   Worklist.pushValue(Shuf0);
1761   Worklist.pushValue(Shuf1);
1762   replaceValue(I, *NewBO);
1763   return true;
1764 }
1765 
1766 /// Try to convert "shuffle (castop x), (castop y)", where the casts have
1767 /// compatible opcodes and a common source type, into "castop (shuffle x, y)".
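/// A minimal illustrative example (hypothetical IR, names made up):
///   %c0 = sext <4 x i16> %x to <4 x i32>
///   %c1 = sext <4 x i16> %y to <4 x i32>
///   %r = shufflevector <4 x i32> %c0, <4 x i32> %c1,
///                      <4 x i32> <i32 0, i32 4, i32 1, i32 5>
/// may become:
///   %s = shufflevector <4 x i16> %x, <4 x i16> %y,
///                      <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %r = sext <4 x i16> %s to <4 x i32>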
1768 bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
1769   Value *V0, *V1;
1770   ArrayRef<int> OldMask;
1771   if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
1772     return false;
1773 
1774   auto *C0 = dyn_cast<CastInst>(V0);
1775   auto *C1 = dyn_cast<CastInst>(V1);
1776   if (!C0 || !C1)
1777     return false;
1778 
1779   Instruction::CastOps Opcode = C0->getOpcode();
1780   if (C0->getSrcTy() != C1->getSrcTy())
1781     return false;
1782 
1783   // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
1784   if (Opcode != C1->getOpcode()) {
1785     if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
1786       Opcode = Instruction::SExt;
1787     else
1788       return false;
1789   }
1790 
1791   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1792   auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
1793   auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
1794   if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
1795     return false;
1796 
1797   unsigned NumSrcElts = CastSrcTy->getNumElements();
1798   unsigned NumDstElts = CastDstTy->getNumElements();
1799   assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
1800          "Only bitcasts expected to alter src/dst element counts");
1801 
1802   // Check for bitcasts between vector types whose element counts are not
1803   // multiples of one another, e.g. <32 x i40> -> <40 x i32>.
1804   if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
1805       (NumDstElts % NumSrcElts) != 0)
1806     return false;
1807 
1808   SmallVector<int, 16> NewMask;
1809   if (NumSrcElts >= NumDstElts) {
1810     // The bitcast is from wide to narrow/equal elements. The shuffle mask can
1811     // always be expanded to the equivalent form choosing narrower elements.
1812     assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
1813     unsigned ScaleFactor = NumSrcElts / NumDstElts;
1814     narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
1815   } else {
1816     // The bitcast is from narrow elements to wide elements. The shuffle mask
1817     // must choose consecutive elements to allow casting first.
1818     assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
1819     unsigned ScaleFactor = NumDstElts / NumSrcElts;
1820     if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
1821       return false;
1822   }
1823 
1824   auto *NewShuffleDstTy =
1825       FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());
1826 
1827   // Try to replace a castop with a shuffle if the shuffle is not costly.
1828   InstructionCost CostC0 =
1829       TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
1830                            TTI::CastContextHint::None, CostKind);
1831   InstructionCost CostC1 =
1832       TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
1833                            TTI::CastContextHint::None, CostKind);
1834   InstructionCost OldCost = CostC0 + CostC1;
1835   OldCost +=
1836       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy,
1837                          OldMask, CostKind, 0, nullptr, {}, &I);
1838 
1839   InstructionCost NewCost = TTI.getShuffleCost(
1840       TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind);
1841   NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
1842                                   TTI::CastContextHint::None, CostKind);
1843   if (!C0->hasOneUse())
1844     NewCost += CostC0;
1845   if (!C1->hasOneUse())
1846     NewCost += CostC1;
1847 
1848   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
1849                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1850                     << "\n");
1851   if (NewCost > OldCost)
1852     return false;
1853 
1854   Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
1855                                             C1->getOperand(0), NewMask);
1856   Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
1857 
1858   // Intersect flags from the old casts.
1859   if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
1860     NewInst->copyIRFlags(C0);
1861     NewInst->andIRFlags(C1);
1862   }
1863 
1864   Worklist.pushValue(Shuf);
1865   replaceValue(I, *Cast);
1866   return true;
1867 }
1868 
1869 /// Try to convert any of:
1870 /// "shuffle (shuffle x, undef), (shuffle y, undef)"
1871 /// "shuffle (shuffle x, undef), y"
1872 /// "shuffle x, (shuffle y, undef)"
1873 /// into "shuffle x, y".
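/// A minimal illustrative example (hypothetical IR, names made up):
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> poison,
///                       <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> poison,
///                       <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %r = shufflevector <4 x i32> %s0, <4 x i32> %s1,
///                      <4 x i32> <i32 0, i32 1, i32 4, i32 5>
/// may become:
///   %r = shufflevector <4 x i32> %x, <4 x i32> %y,
///                      <4 x i32> <i32 3, i32 2, i32 7, i32 6>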
1874 bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
1875   ArrayRef<int> OuterMask;
1876   Value *OuterV0, *OuterV1;
1877   if (!match(&I,
1878              m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask))))
1879     return false;
1880 
1881   ArrayRef<int> InnerMask0, InnerMask1;
1882   Value *V0 = nullptr, *V1 = nullptr;
1883   UndefValue *U0 = nullptr, *U1 = nullptr;
1884   bool Match0 = match(
1885       OuterV0, m_Shuffle(m_Value(V0), m_UndefValue(U0), m_Mask(InnerMask0)));
1886   bool Match1 = match(
1887       OuterV1, m_Shuffle(m_Value(V1), m_UndefValue(U1), m_Mask(InnerMask1)));
1888   if (!Match0 && !Match1)
1889     return false;
1890 
1891   V0 = Match0 ? V0 : OuterV0;
1892   V1 = Match1 ? V1 : OuterV1;
1893   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1894   auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType());
1895   auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType());
1896   if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
1897       V0->getType() != V1->getType())
1898     return false;
1899 
1900   unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
1901   unsigned NumImmElts = ShuffleImmTy->getNumElements();
1902 
1903   // Bail if either inner mask references a RHS undef arg.
1904   if ((Match0 && !isa<PoisonValue>(U0) &&
1905        any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) ||
1906       (Match1 && !isa<PoisonValue>(U1) &&
1907        any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; })))
1908     return false;
1909 
1910   // Merge shuffles - replace indices into the RHS poison arg with PoisonMaskElem.
1911   SmallVector<int, 16> NewMask(OuterMask);
1912   for (int &M : NewMask) {
1913     if (0 <= M && M < (int)NumImmElts) {
1914       if (Match0)
1915         M = (InnerMask0[M] >= (int)NumSrcElts) ? PoisonMaskElem : InnerMask0[M];
1916     } else if (M >= (int)NumImmElts) {
1917       if (Match1) {
1918         if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts)
1919           M = PoisonMaskElem;
1920         else
1921           M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts);
1922       }
1923     }
1924   }
1925 
1926   // Have we folded to an Identity shuffle?
1927   if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
1928     replaceValue(I, *V0);
1929     return true;
1930   }
1931 
1932   // Try to merge the shuffles if the new shuffle is not costly.
1933   InstructionCost InnerCost0 = 0;
1934   if (Match0)
1935     InnerCost0 = TTI.getShuffleCost(
1936         TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, InnerMask0,
1937         CostKind, 0, nullptr, {V0, U0}, cast<ShuffleVectorInst>(OuterV0));
1938 
1939   InstructionCost InnerCost1 = 0;
1940   if (Match1)
1941     InnerCost1 = TTI.getShuffleCost(
1942         TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, InnerMask1,
1943         CostKind, 0, nullptr, {V1, U1}, cast<ShuffleVectorInst>(OuterV1));
1944 
1945   InstructionCost OuterCost = TTI.getShuffleCost(
1946       TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind,
1947       0, nullptr, {OuterV0, OuterV1}, &I);
1948 
1949   InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
1950 
1951   InstructionCost NewCost =
1952       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy,
1953                          NewMask, CostKind, 0, nullptr, {V0, V1});
1954   if (!OuterV0->hasOneUse())
1955     NewCost += InnerCost0;
1956   if (!OuterV1->hasOneUse())
1957     NewCost += InnerCost1;
1958 
1959   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
1960                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1961                     << "\n");
1962   if (NewCost > OldCost)
1963     return false;
1964 
1965   // Clear unused sources to poison.
1966   if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; }))
1967     V0 = PoisonValue::get(ShuffleSrcTy);
1968   if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; }))
1969     V1 = PoisonValue::get(ShuffleSrcTy);
1970 
1971   Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask);
1972   replaceValue(I, *Shuf);
1973   return true;
1974 }
1975 
1976 /// Try to convert
1977 /// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
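/// A minimal illustrative example (hypothetical IR, names made up):
///   %i0 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %a, <4 x i32> %b)
///   %i1 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %c, <4 x i32> %d)
///   %r  = shufflevector <4 x i32> %i0, <4 x i32> %i1,
///                       <4 x i32> <i32 0, i32 5, i32 2, i32 7>
/// may become:
///   %s0 = shufflevector <4 x i32> %a, <4 x i32> %c,
///                       <4 x i32> <i32 0, i32 5, i32 2, i32 7>
///   %s1 = shufflevector <4 x i32> %b, <4 x i32> %d,
///                       <4 x i32> <i32 0, i32 5, i32 2, i32 7>
///   %r  = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %s0, <4 x i32> %s1)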
1978 bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
1979   Value *V0, *V1;
1980   ArrayRef<int> OldMask;
1981   if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
1982                            m_Mask(OldMask))))
1983     return false;
1984 
1985   auto *II0 = dyn_cast<IntrinsicInst>(V0);
1986   auto *II1 = dyn_cast<IntrinsicInst>(V1);
1987   if (!II0 || !II1)
1988     return false;
1989 
1990   Intrinsic::ID IID = II0->getIntrinsicID();
1991   if (IID != II1->getIntrinsicID())
1992     return false;
1993 
1994   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1995   auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType());
1996   if (!ShuffleDstTy || !II0Ty)
1997     return false;
1998 
1999   if (!isTriviallyVectorizable(IID))
2000     return false;
2001 
2002   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2003     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI) &&
2004         II0->getArgOperand(I) != II1->getArgOperand(I))
2005       return false;
2006 
2007   InstructionCost OldCost =
2008       TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) +
2009       TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind) +
2010       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, II0Ty, OldMask,
2011                          CostKind, 0, nullptr, {II0, II1}, &I);
2012 
2013   SmallVector<Type *> NewArgsTy;
2014   InstructionCost NewCost = 0;
2015   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2016     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
2017       NewArgsTy.push_back(II0->getArgOperand(I)->getType());
2018     } else {
2019       auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
2020       NewArgsTy.push_back(FixedVectorType::get(VecTy->getElementType(),
2021                                                VecTy->getNumElements() * 2));
2022       NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
2023                                     VecTy, OldMask, CostKind);
2024     }
2025   IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
2026   NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
2027 
2028   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
2029                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
2030                     << "\n");
2031 
2032   if (NewCost > OldCost)
2033     return false;
2034 
2035   SmallVector<Value *> NewArgs;
2036   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2037     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
2038       NewArgs.push_back(II0->getArgOperand(I));
2039     } else {
2040       Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I),
2041                                                 II1->getArgOperand(I), OldMask);
2042       NewArgs.push_back(Shuf);
2043       Worklist.pushValue(Shuf);
2044     }
2045   Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
2046 
2047   // Intersect flags from the old intrinsics.
2048   if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) {
2049     NewInst->copyIRFlags(II0);
2050     NewInst->andIRFlags(II1);
2051   }
2052 
2053   replaceValue(I, *NewIntrinsic);
2054   return true;
2055 }
2056 
2057 using InstLane = std::pair<Use *, int>;
2058 
2059 static InstLane lookThroughShuffles(Use *U, int Lane) {
2060   while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) {
2061     unsigned NumElts =
2062         cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
2063     int M = SV->getMaskValue(Lane);
2064     if (M < 0)
2065       return {nullptr, PoisonMaskElem};
2066     if (static_cast<unsigned>(M) < NumElts) {
2067       U = &SV->getOperandUse(0);
2068       Lane = M;
2069     } else {
2070       U = &SV->getOperandUse(1);
2071       Lane = M - NumElts;
2072     }
2073   }
2074   return InstLane{U, Lane};
2075 }
2076 
2077 static SmallVector<InstLane>
2078 generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
2079   SmallVector<InstLane> NItem;
2080   for (InstLane IL : Item) {
2081     auto [U, Lane] = IL;
2082     InstLane OpLane =
2083         U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op),
2084                                 Lane)
2085           : InstLane{nullptr, PoisonMaskElem};
2086     NItem.emplace_back(OpLane);
2087   }
2088   return NItem;
2089 }
2090 
2091 /// Detect concat of multiple values into a vector
2092 static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
2093                          const TargetTransformInfo &TTI) {
2094   auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType());
2095   unsigned NumElts = Ty->getNumElements();
2096   if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
2097     return false;
2098 
2099   // Check that the concat is free, usually meaning that the type will be split
2100   // during legalization.
2101   SmallVector<int, 16> ConcatMask(NumElts * 2);
2102   std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
2103   if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, CostKind) != 0)
2104     return false;
2105 
2106   unsigned NumSlices = Item.size() / NumElts;
2107   // Currently we generate a tree of shuffles for the concats, which limits us
2108   // to a power of 2.
2109   if (!isPowerOf2_32(NumSlices))
2110     return false;
2111   for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
2112     Use *SliceV = Item[Slice * NumElts].first;
2113     if (!SliceV || SliceV->get()->getType() != Ty)
2114       return false;
2115     for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
2116       auto [V, Lane] = Item[Slice * NumElts + Elt];
2117       if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
2118         return false;
2119     }
2120   }
2121   return true;
2122 }
2123 
2124 static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
2125                                   const SmallPtrSet<Use *, 4> &IdentityLeafs,
2126                                   const SmallPtrSet<Use *, 4> &SplatLeafs,
2127                                   const SmallPtrSet<Use *, 4> &ConcatLeafs,
2128                                   IRBuilder<> &Builder,
2129                                   const TargetTransformInfo *TTI) {
2130   auto [FrontU, FrontLane] = Item.front();
2131 
2132   if (IdentityLeafs.contains(FrontU)) {
2133     return FrontU->get();
2134   }
2135   if (SplatLeafs.contains(FrontU)) {
2136     SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
2137     return Builder.CreateShuffleVector(FrontU->get(), Mask);
2138   }
2139   if (ConcatLeafs.contains(FrontU)) {
2140     unsigned NumElts =
2141         cast<FixedVectorType>(FrontU->get()->getType())->getNumElements();
2142     SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
2143     for (unsigned S = 0; S < Values.size(); ++S)
2144       Values[S] = Item[S * NumElts].first->get();
2145 
2146     while (Values.size() > 1) {
2147       NumElts *= 2;
2148       SmallVector<int, 16> Mask(NumElts, 0);
2149       std::iota(Mask.begin(), Mask.end(), 0);
2150       SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
2151       for (unsigned S = 0; S < NewValues.size(); ++S)
2152         NewValues[S] =
2153             Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
2154       Values = NewValues;
2155     }
2156     return Values[0];
2157   }
2158 
2159   auto *I = cast<Instruction>(FrontU->get());
2160   auto *II = dyn_cast<IntrinsicInst>(I);
2161   unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
2162   SmallVector<Value *> Ops(NumOps);
2163   for (unsigned Idx = 0; Idx < NumOps; Idx++) {
2164     if (II &&
2165         isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) {
2166       Ops[Idx] = II->getOperand(Idx);
2167       continue;
2168     }
2169     Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
2170                                    Ty, IdentityLeafs, SplatLeafs, ConcatLeafs,
2171                                    Builder, TTI);
2172   }
2173 
2174   SmallVector<Value *, 8> ValueList;
2175   for (const auto &Lane : Item)
2176     if (Lane.first)
2177       ValueList.push_back(Lane.first->get());
2178 
2179   Type *DstTy =
2180       FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
2181   if (auto *BI = dyn_cast<BinaryOperator>(I)) {
2182     auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
2183                                       Ops[0], Ops[1]);
2184     propagateIRFlags(Value, ValueList);
2185     return Value;
2186   }
2187   if (auto *CI = dyn_cast<CmpInst>(I)) {
2188     auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
2189     propagateIRFlags(Value, ValueList);
2190     return Value;
2191   }
2192   if (auto *SI = dyn_cast<SelectInst>(I)) {
2193     auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
2194     propagateIRFlags(Value, ValueList);
2195     return Value;
2196   }
2197   if (auto *CI = dyn_cast<CastInst>(I)) {
2198     auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(),
2199                                      Ops[0], DstTy);
2200     propagateIRFlags(Value, ValueList);
2201     return Value;
2202   }
2203   if (II) {
2204     auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
2205     propagateIRFlags(Value, ValueList);
2206     return Value;
2207   }
2208   assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
2209   auto *Value =
2210       Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
2211   propagateIRFlags(Value, ValueList);
2212   return Value;
2213 }
2214 
2215 // Starting from a shuffle, look up through operands tracking the shuffled index
2216 // of each lane. If we can simplify away the shuffles to identities then
2217 // do so.
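// A minimal illustrative example (hypothetical IR, names made up): the
// reversing shuffles in
//   %r0 = shufflevector <4 x i32> %a, <4 x i32> poison,
//                       <4 x i32> <i32 3, i32 2, i32 1, i32 0>
//   %r1 = shufflevector <4 x i32> %b, <4 x i32> poison,
//                       <4 x i32> <i32 3, i32 2, i32 1, i32 0>
//   %add = add <4 x i32> %r0, %r1
//   %res = shufflevector <4 x i32> %add, <4 x i32> poison,
//                        <4 x i32> <i32 3, i32 2, i32 1, i32 0>
// cancel out lane-by-lane, so the tree can be rebuilt as
//   %res = add <4 x i32> %a, %b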
2218 bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
2219   auto *Ty = dyn_cast<FixedVectorType>(I.getType());
2220   if (!Ty || I.use_empty())
2221     return false;
2222 
2223   SmallVector<InstLane> Start(Ty->getNumElements());
2224   for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
2225     Start[M] = lookThroughShuffles(&*I.use_begin(), M);
2226 
2227   SmallVector<SmallVector<InstLane>> Worklist;
2228   Worklist.push_back(Start);
2229   SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
2230   unsigned NumVisited = 0;
2231 
2232   while (!Worklist.empty()) {
2233     if (++NumVisited > MaxInstrsToScan)
2234       return false;
2235 
2236     SmallVector<InstLane> Item = Worklist.pop_back_val();
2237     auto [FrontU, FrontLane] = Item.front();
2238 
2239     // If we found an undef first lane then bail out to keep things simple.
2240     if (!FrontU)
2241       return false;
2242 
2243     // Helper to peek through bitcasts to the same value.
2244     auto IsEquiv = [&](Value *X, Value *Y) {
2245       return X->getType() == Y->getType() &&
2246              peekThroughBitcasts(X) == peekThroughBitcasts(Y);
2247     };
2248 
2249     // Look for an identity value.
2250     if (FrontLane == 0 &&
2251         cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() ==
2252             Ty->getNumElements() &&
2253         all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) {
2254           Value *FrontV = Item.front().first->get();
2255           return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) &&
2256                                       E.value().second == (int)E.index());
2257         })) {
2258       IdentityLeafs.insert(FrontU);
2259       continue;
2260     }
2261     // Look for constants, for the moment only supporting constant splats.
2262     if (auto *C = dyn_cast<Constant>(FrontU);
2263         C && C->getSplatValue() &&
2264         all_of(drop_begin(Item), [Item](InstLane &IL) {
2265           Value *FrontV = Item.front().first->get();
2266           Use *U = IL.first;
2267           return !U || (isa<Constant>(U->get()) &&
2268                         cast<Constant>(U->get())->getSplatValue() ==
2269                             cast<Constant>(FrontV)->getSplatValue());
2270         })) {
2271       SplatLeafs.insert(FrontU);
2272       continue;
2273     }
2274     // Look for a splat value.
2275     if (all_of(drop_begin(Item), [Item](InstLane &IL) {
2276           auto [FrontU, FrontLane] = Item.front();
2277           auto [U, Lane] = IL;
2278           return !U || (U->get() == FrontU->get() && Lane == FrontLane);
2279         })) {
2280       SplatLeafs.insert(FrontU);
2281       continue;
2282     }
2283 
2284     // We need each element to be the same type of value, and check that each
2285     // element has a single use.
2286     auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
2287       Value *FrontV = Item.front().first->get();
2288       if (!IL.first)
2289         return true;
2290       Value *V = IL.first->get();
2291       if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
2292         return false;
2293       if (V->getValueID() != FrontV->getValueID())
2294         return false;
2295       if (auto *CI = dyn_cast<CmpInst>(V))
2296         if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
2297           return false;
2298       if (auto *CI = dyn_cast<CastInst>(V))
2299         if (CI->getSrcTy()->getScalarType() !=
2300             cast<CastInst>(FrontV)->getSrcTy()->getScalarType())
2301           return false;
2302       if (auto *SI = dyn_cast<SelectInst>(V))
2303         if (!isa<VectorType>(SI->getOperand(0)->getType()) ||
2304             SI->getOperand(0)->getType() !=
2305                 cast<SelectInst>(FrontV)->getOperand(0)->getType())
2306           return false;
2307       if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
2308         return false;
2309       auto *II = dyn_cast<IntrinsicInst>(V);
2310       return !II || (isa<IntrinsicInst>(FrontV) &&
2311                      II->getIntrinsicID() ==
2312                          cast<IntrinsicInst>(FrontV)->getIntrinsicID() &&
2313                      !II->hasOperandBundles());
2314     };
2315     if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) {
2316       // Check the operator is one that we support.
2317       if (isa<BinaryOperator, CmpInst>(FrontU)) {
2318         // We exclude div/rem in case they hit UB from poison lanes.
2319         if (auto *BO = dyn_cast<BinaryOperator>(FrontU);
2320             BO && BO->isIntDivRem())
2321           return false;
2322         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2323         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
2324         continue;
2325       } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
2326                      FPToUIInst, SIToFPInst, UIToFPInst>(FrontU)) {
2327         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2328         continue;
2329       } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) {
2330         // TODO: Handle vector widening/narrowing bitcasts.
2331         auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
2332         auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
2333         if (DstTy && SrcTy &&
2334             SrcTy->getNumElements() == DstTy->getNumElements()) {
2335           Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2336           continue;
2337         }
2338       } else if (isa<SelectInst>(FrontU)) {
2339         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2340         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
2341         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
2342         continue;
2343       } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
2344                  II && isTriviallyVectorizable(II->getIntrinsicID()) &&
2345                  !II->hasOperandBundles()) {
2346         for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
2347           if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op,
2348                                                  &TTI)) {
2349             if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
2350                   Value *FrontV = Item.front().first->get();
2351                   Use *U = IL.first;
2352                   return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
2353                                 cast<Instruction>(FrontV)->getOperand(Op));
2354                 }))
2355               return false;
2356             continue;
2357           }
2358           Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
2359         }
2360         continue;
2361       }
2362     }
2363 
2364     if (isFreeConcat(Item, CostKind, TTI)) {
2365       ConcatLeafs.insert(FrontU);
2366       continue;
2367     }
2368 
2369     return false;
2370   }
2371 
2372   if (NumVisited <= 1)
2373     return false;
2374 
2375   // If we got this far, we know the shuffles are superfluous and can be
2376   // removed. Scan through again and generate the new tree of instructions.
2377   Builder.SetInsertPoint(&I);
2378   Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
2379                                  ConcatLeafs, Builder, &TTI);
2380   replaceValue(I, *V);
2381   return true;
2382 }
2383 
2384 /// Given a commutative reduction, the order of the input lanes does not alter
2385 /// the results. We can use this to remove certain shuffles feeding the
2386 /// reduction, removing the need to shuffle at all.
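/// A minimal illustrative example (hypothetical IR, names made up):
///   %s = shufflevector <4 x i32> %a, <4 x i32> poison,
///                      <4 x i32> <i32 3, i32 1, i32 2, i32 0>
///   %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %s)
/// The reduction is insensitive to lane order, so when the sorted mask is
/// cheaper the shuffle is rewritten with it (here an identity), which later
/// folds can then remove entirely:
///   %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)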
2387 bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2388   auto *II = dyn_cast<IntrinsicInst>(&I);
2389   if (!II)
2390     return false;
2391   switch (II->getIntrinsicID()) {
2392   case Intrinsic::vector_reduce_add:
2393   case Intrinsic::vector_reduce_mul:
2394   case Intrinsic::vector_reduce_and:
2395   case Intrinsic::vector_reduce_or:
2396   case Intrinsic::vector_reduce_xor:
2397   case Intrinsic::vector_reduce_smin:
2398   case Intrinsic::vector_reduce_smax:
2399   case Intrinsic::vector_reduce_umin:
2400   case Intrinsic::vector_reduce_umax:
2401     break;
2402   default:
2403     return false;
2404   }
2405 
2406   // Find all the inputs when looking through operations that do not alter the
2407   // lane order (binops, for example). Currently we look for a single shuffle,
2408   // and can ignore splat values.
2409   std::queue<Value *> Worklist;
2410   SmallPtrSet<Value *, 4> Visited;
2411   ShuffleVectorInst *Shuffle = nullptr;
2412   if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
2413     Worklist.push(Op);
2414 
2415   while (!Worklist.empty()) {
2416     Value *CV = Worklist.front();
2417     Worklist.pop();
2418     if (Visited.contains(CV))
2419       continue;
2420 
2421     // Splats don't change the order, so can be safely ignored.
2422     if (isSplatValue(CV))
2423       continue;
2424 
2425     Visited.insert(CV);
2426 
2427     if (auto *CI = dyn_cast<Instruction>(CV)) {
2428       if (CI->isBinaryOp()) {
2429         for (auto *Op : CI->operand_values())
2430           Worklist.push(Op);
2431         continue;
2432       } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
2433         if (Shuffle && Shuffle != SV)
2434           return false;
2435         Shuffle = SV;
2436         continue;
2437       }
2438     }
2439 
2440     // Anything else is currently an unknown node.
2441     return false;
2442   }
2443 
2444   if (!Shuffle)
2445     return false;
2446 
2447   // Check all uses of the binary ops and shuffles are also included in the
2448   // lane-invariant operations (Visited should be the list of lanewise
2449   // instructions, including the shuffle that we found).
2450   for (auto *V : Visited)
2451     for (auto *U : V->users())
2452       if (!Visited.contains(U) && U != &I)
2453         return false;
2454 
2455   FixedVectorType *VecType =
2456       dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
2457   if (!VecType)
2458     return false;
2459   FixedVectorType *ShuffleInputType =
2460       dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
2461   if (!ShuffleInputType)
2462     return false;
2463   unsigned NumInputElts = ShuffleInputType->getNumElements();
2464 
2465   // Find the mask from sorting the lanes into order. This is most likely to
2466   // become an identity or concat mask. Undef elements are pushed to the end.
2467   SmallVector<int> ConcatMask;
2468   Shuffle->getShuffleMask(ConcatMask);
2469   sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
2470   // In the case of a truncating shuffle it's possible for the mask
2471   // to have an index greater than the size of the resulting vector.
2472   // This requires special handling.
2473   bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts;
2474   bool UsesSecondVec =
2475       any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
2476 
2477   FixedVectorType *VecTyForCost =
2478       (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType;
2479   InstructionCost OldCost = TTI.getShuffleCost(
2480       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
2481       VecTyForCost, Shuffle->getShuffleMask(), CostKind);
2482   InstructionCost NewCost = TTI.getShuffleCost(
2483       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
2484       VecTyForCost, ConcatMask, CostKind);
2485 
2486   LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
2487                     << "\n");
2488   LLVM_DEBUG(dbgs() << "  OldCost: " << OldCost << " vs NewCost: " << NewCost
2489                     << "\n");
2490   if (NewCost < OldCost) {
2491     Builder.SetInsertPoint(Shuffle);
2492     Value *NewShuffle = Builder.CreateShuffleVector(
2493         Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
2494     LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
2495     replaceValue(*Shuffle, *NewShuffle);
2496   }
2497 
2498   // See if we can re-use foldSelectShuffle, getting it to reduce the size of
2499   // See if we can re-use foldSelectShuffle, getting it to shrink the shuffles
2500   // or put them into a nicer order, as the reduction can ignore the lane order.
2501 }
2502 
2503 /// Determine if it's more efficient to fold:
2504 ///   reduce(trunc(x)) -> trunc(reduce(x)).
2505 ///   reduce(sext(x))  -> sext(reduce(x)).
2506 ///   reduce(zext(x))  -> zext(reduce(x)).
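/// For example (illustrative IR, names made up):
///   %z = zext <8 x i16> %x to <8 x i32>
///   %r = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %z)
/// may become:
///   %n = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> %x)
///   %r = zext i16 %n to i32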
2507 bool VectorCombine::foldCastFromReductions(Instruction &I) {
2508   auto *II = dyn_cast<IntrinsicInst>(&I);
2509   if (!II)
2510     return false;
2511 
2512   bool TruncOnly = false;
2513   Intrinsic::ID IID = II->getIntrinsicID();
2514   switch (IID) {
2515   case Intrinsic::vector_reduce_add:
2516   case Intrinsic::vector_reduce_mul:
2517     TruncOnly = true;
2518     break;
2519   case Intrinsic::vector_reduce_and:
2520   case Intrinsic::vector_reduce_or:
2521   case Intrinsic::vector_reduce_xor:
2522     break;
2523   default:
2524     return false;
2525   }
2526 
2527   unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
2528   Value *ReductionSrc = I.getOperand(0);
2529 
2530   Value *Src;
2531   if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
2532       (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
2533     return false;
2534 
2535   auto CastOpc =
2536       (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
2537 
2538   auto *SrcTy = cast<VectorType>(Src->getType());
2539   auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
2540   Type *ResultTy = I.getType();
2541 
2542   InstructionCost OldCost = TTI.getArithmeticReductionCost(
2543       ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
2544   OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
2545                                   TTI::CastContextHint::None, CostKind,
2546                                   cast<CastInst>(ReductionSrc));
2547   InstructionCost NewCost =
2548       TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
2549                                      CostKind) +
2550       TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
2551                            TTI::CastContextHint::None, CostKind);
2552 
2553   if (OldCost <= NewCost || !NewCost.isValid())
2554     return false;
2555 
2556   Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
2557                                                 II->getIntrinsicID(), {Src});
2558   Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
2559   replaceValue(I, *NewCast);
2560   return true;
2561 }
2562 
2563 /// This method looks for groups of shuffles acting on binops, of the form:
2564 ///  %x = shuffle ...
2565 ///  %y = shuffle ...
2566 ///  %a = binop %x, %y
2567 ///  %b = binop %x, %y
2568 ///  shuffle %a, %b, selectmask
2569 /// We may, especially if the shuffle is wider than legal, be able to convert
2570 /// the shuffle to a form where only parts of a and b need to be computed. On
2571 /// architectures with no obvious "select" shuffle, this can reduce the total
2572 /// number of operations if the target reports them as cheaper.
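/// An illustrative sketch (not literal output) of the rewritten form: the
/// input shuffles pack the lanes each binop actually feeds into the low
/// elements (trailing lanes become poison), and a final shuffle with a
/// reconstruction mask restores the original lane order:
///  %x1 = shuffle ...
///  %y1 = shuffle ...
///  %a1 = binop %x1, %y1
///  %b1 = binop %x2, %y2
///  shuffle %a1, %b1, reconstructmask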
2573 bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
2574   auto *SVI = cast<ShuffleVectorInst>(&I);
2575   auto *VT = cast<FixedVectorType>(I.getType());
2576   auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
2577   auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
2578   if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
2579       VT != Op0->getType())
2580     return false;
2581 
2582   auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
2583   auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
2584   auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
2585   auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
2586   SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
2587   auto checkSVNonOpUses = [&](Instruction *I) {
2588     if (!I || I->getOperand(0)->getType() != VT)
2589       return true;
2590     return any_of(I->users(), [&](User *U) {
2591       return U != Op0 && U != Op1 &&
2592              !(isa<ShuffleVectorInst>(U) &&
2593                (InputShuffles.contains(cast<Instruction>(U)) ||
2594                 isInstructionTriviallyDead(cast<Instruction>(U))));
2595     });
2596   };
2597   if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
2598       checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
2599     return false;
2600 
2601   // Collect all the uses that are shuffles that we can transform together. We
2602   // may not have a single shuffle, but a group that can all be transformed
2603   // together profitably.
2604   SmallVector<ShuffleVectorInst *> Shuffles;
2605   auto collectShuffles = [&](Instruction *I) {
2606     for (auto *U : I->users()) {
2607       auto *SV = dyn_cast<ShuffleVectorInst>(U);
2608       if (!SV || SV->getType() != VT)
2609         return false;
2610       if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
2611           (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
2612         return false;
2613       if (!llvm::is_contained(Shuffles, SV))
2614         Shuffles.push_back(SV);
2615     }
2616     return true;
2617   };
2618   if (!collectShuffles(Op0) || !collectShuffles(Op1))
2619     return false;
2620   // From a reduction, we need to be processing a single shuffle, otherwise the
2621   // other uses will not be lane-invariant.
2622   if (FromReduction && Shuffles.size() > 1)
2623     return false;
2624 
2625   // Add any users of the shuffles that are themselves shuffles, to include
2626   // them in our cost calculations.
2627   if (!FromReduction) {
2628     for (ShuffleVectorInst *SV : Shuffles) {
2629       for (auto *U : SV->users()) {
2630         ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
2631         if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
2632           Shuffles.push_back(SSV);
2633       }
2634     }
2635   }
2636 
2637   // For each of the output shuffles, we try to sort all the first vector
2638   // elements to the beginning, followed by the second vector's elements at the
2639   // end. If the binops are legalized to smaller vectors, this may reduce the
2640   // total number of binops. We compute the ReconstructMask needed to convert
2641   // back to the original lane order.
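  // Illustrative example (hypothetical mask): with NumElts == 4 and an output
  // shuffle mask of <2, 0, 5, 4>, the Op0 lanes pack as V1 = {2, 0}, the Op1
  // lanes pack as V2 = {1, 0}, and the ReconstructMask over the packed vectors
  // is <0, 1, 4, 5>.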
2642   SmallVector<std::pair<int, int>> V1, V2;
2643   SmallVector<SmallVector<int>> OrigReconstructMasks;
2644   int MaxV1Elt = 0, MaxV2Elt = 0;
2645   unsigned NumElts = VT->getNumElements();
2646   for (ShuffleVectorInst *SVN : Shuffles) {
2647     SmallVector<int> Mask;
2648     SVN->getShuffleMask(Mask);
2649 
2650     // Check the operands are the same as the original, or reversed (in which
2651     // case we need to commute the mask).
2652     Value *SVOp0 = SVN->getOperand(0);
2653     Value *SVOp1 = SVN->getOperand(1);
2654     if (isa<UndefValue>(SVOp1)) {
2655       auto *SSV = cast<ShuffleVectorInst>(SVOp0);
2656       SVOp0 = SSV->getOperand(0);
2657       SVOp1 = SSV->getOperand(1);
2658       for (unsigned I = 0, E = Mask.size(); I != E; I++) {
2659         if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
2660           return false;
2661         Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
2662       }
2663     }
2664     if (SVOp0 == Op1 && SVOp1 == Op0) {
2665       std::swap(SVOp0, SVOp1);
2666       ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
2667     }
2668     if (SVOp0 != Op0 || SVOp1 != Op1)
2669       return false;
2670 
2671     // Calculate the reconstruction mask for this shuffle, as the mask needed to
2672     // take the packed values from Op0/Op1 and reconstruct the original lane
2673     // order.
2674     SmallVector<int> ReconstructMask;
2675     for (unsigned I = 0; I < Mask.size(); I++) {
2676       if (Mask[I] < 0) {
2677         ReconstructMask.push_back(-1);
2678       } else if (Mask[I] < static_cast<int>(NumElts)) {
2679         MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
2680         auto It = find_if(V1, [&](const std::pair<int, int> &A) {
2681           return Mask[I] == A.first;
2682         });
2683         if (It != V1.end())
2684           ReconstructMask.push_back(It - V1.begin());
2685         else {
2686           ReconstructMask.push_back(V1.size());
2687           V1.emplace_back(Mask[I], V1.size());
2688         }
2689       } else {
2690         MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
2691         auto It = find_if(V2, [&](const std::pair<int, int> &A) {
2692           return Mask[I] - static_cast<int>(NumElts) == A.first;
2693         });
2694         if (It != V2.end())
2695           ReconstructMask.push_back(NumElts + It - V2.begin());
2696         else {
2697           ReconstructMask.push_back(NumElts + V2.size());
2698           V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
2699         }
2700       }
2701     }
2702 
2703     // For reductions, we know that the output lane ordering doesn't alter the
2704     // result. An in-order mask can help simplify the shuffle away.
2705     if (FromReduction)
2706       sort(ReconstructMask);
2707     OrigReconstructMasks.push_back(std::move(ReconstructMask));
2708   }
2709 
2710   // If the maximum elements used from V1 and V2 are not larger than the new
2711   // vectors, the vectors are already packed and performing the optimization
2712   // again will likely not help any further. This also prevents us from getting
2713   // stuck in a cycle in case the costs do not also rule it out.
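  // For example (hypothetical sizes), if V1 holds 3 lanes with MaxV1Elt == 2
  // and V2 holds 2 lanes with MaxV2Elt == 1, the used lanes already occupy the
  // low elements of both inputs, so re-packing them cannot help.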
2714   if (V1.empty() || V2.empty() ||
2715       (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
2716        MaxV2Elt == static_cast<int>(V2.size()) - 1))
2717     return false;
2718 
2719   // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
2720   // shuffle of another shuffle, or not a shuffle (which is treated like an
2721   // identity shuffle).
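  // For example (hypothetical masks), if I is a single-source shuffle with mask
  // <3, 1, 0, 2> whose operand is an input shuffle SSV with mask <2, 0, 1, 3>,
  // then GetBaseMaskValue(I, 0) looks through both and returns
  // SSV->getMaskValue(3) == 3.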
2722   auto GetBaseMaskValue = [&](Instruction *I, int M) {
2723     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2724     if (!SV)
2725       return M;
2726     if (isa<UndefValue>(SV->getOperand(1)))
2727       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
2728         if (InputShuffles.contains(SSV))
2729           return SSV->getMaskValue(SV->getMaskValue(M));
2730     return SV->getMaskValue(M);
2731   };
2732 
2733   // Attempt to sort the inputs by ascending mask values to make simpler input
2734   // shuffles and push complex shuffles down to the uses. We sort on the first
2735   // of the two input shuffle orders, to try and get at least one input into a
2736   // nice order.
2737   auto SortBase = [&](Instruction *A, std::pair<int, int> X,
2738                       std::pair<int, int> Y) {
2739     int MXA = GetBaseMaskValue(A, X.first);
2740     int MYA = GetBaseMaskValue(A, Y.first);
2741     return MXA < MYA;
2742   };
2743   stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
2744     return SortBase(SVI0A, A, B);
2745   });
2746   stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
2747     return SortBase(SVI1A, A, B);
2748   });
2749   // Calculate our ReconstructMasks from the OrigReconstructMasks and the
2750   // modified order of the input shuffles.
2751   SmallVector<SmallVector<int>> ReconstructMasks;
2752   for (const auto &Mask : OrigReconstructMasks) {
2753     SmallVector<int> ReconstructMask;
2754     for (int M : Mask) {
2755       auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
2756         auto It = find_if(V, [M](auto A) { return A.second == M; });
2757         assert(It != V.end() && "Expected all entries in Mask");
2758         return std::distance(V.begin(), It);
2759       };
2760       if (M < 0)
2761         ReconstructMask.push_back(-1);
2762       else if (M < static_cast<int>(NumElts)) {
2763         ReconstructMask.push_back(FindIndex(V1, M));
2764       } else {
2765         ReconstructMask.push_back(NumElts + FindIndex(V2, M));
2766       }
2767     }
2768     ReconstructMasks.push_back(std::move(ReconstructMask));
2769   }
2770 
2771   // Calculate the masks needed for the new input shuffles, which get padded
2772   // with poison elements.
2773   SmallVector<int> V1A, V1B, V2A, V2B;
2774   for (unsigned I = 0; I < V1.size(); I++) {
2775     V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
2776     V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
2777   }
2778   for (unsigned I = 0; I < V2.size(); I++) {
2779     V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
2780     V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
2781   }
2782   while (V1A.size() < NumElts) {
2783     V1A.push_back(PoisonMaskElem);
2784     V1B.push_back(PoisonMaskElem);
2785   }
2786   while (V2A.size() < NumElts) {
2787     V2A.push_back(PoisonMaskElem);
2788     V2B.push_back(PoisonMaskElem);
2789   }
2790 
2791   auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
2792     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2793     if (!SV)
2794       return C;
2795     return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
2796                                       ? TTI::SK_PermuteSingleSrc
2797                                       : TTI::SK_PermuteTwoSrc,
2798                                   VT, SV->getShuffleMask(), CostKind);
2799   };
2800   auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
2801     return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask, CostKind);
2802   };
2803 
2804   // Get the costs of the shuffles + binops before and after with the new
2805   // shuffle masks.
2806   InstructionCost CostBefore =
2807       TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
2808       TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
2809   CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
2810                                 InstructionCost(0), AddShuffleCost);
2811   CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
2812                                 InstructionCost(0), AddShuffleCost);
2813 
2814   // The new binops will be unused for lanes past the used shuffle lengths.
2815   // These narrower types let us query the target for the correct cost of that.
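  // For example (hypothetical types), if VT is <16 x i32> but only 4 lanes of
  // Op0's result are used, the replacement binop is costed as a <4 x i32>
  // operation.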
2816   FixedVectorType *Op0SmallVT =
2817       FixedVectorType::get(VT->getScalarType(), V1.size());
2818   FixedVectorType *Op1SmallVT =
2819       FixedVectorType::get(VT->getScalarType(), V2.size());
2820   InstructionCost CostAfter =
2821       TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
2822       TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
2823   CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
2824                                InstructionCost(0), AddShuffleMaskCost);
2825   std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
2826   CostAfter +=
2827       std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
2828                       InstructionCost(0), AddShuffleMaskCost);
2829 
2830   LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
2831   LLVM_DEBUG(dbgs() << "  CostBefore: " << CostBefore
2832                     << " vs CostAfter: " << CostAfter << "\n");
2833   if (CostBefore <= CostAfter)
2834     return false;
2835 
2836   // The cost model has passed, create the new instructions.
2837   auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
2838     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2839     if (!SV)
2840       return I;
2841     if (isa<UndefValue>(SV->getOperand(1)))
2842       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
2843         if (InputShuffles.contains(SSV))
2844           return SSV->getOperand(Op);
2845     return SV->getOperand(Op);
2846   };
2847   Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
2848   Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
2849                                              GetShuffleOperand(SVI0A, 1), V1A);
2850   Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
2851   Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
2852                                              GetShuffleOperand(SVI0B, 1), V1B);
2853   Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
2854   Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
2855                                              GetShuffleOperand(SVI1A, 1), V2A);
2856   Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
2857   Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
2858                                              GetShuffleOperand(SVI1B, 1), V2B);
2859   Builder.SetInsertPoint(Op0);
2860   Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
2861                                     NSV0A, NSV0B);
2862   if (auto *I = dyn_cast<Instruction>(NOp0))
2863     I->copyIRFlags(Op0, true);
2864   Builder.SetInsertPoint(Op1);
2865   Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
2866                                     NSV1A, NSV1B);
2867   if (auto *I = dyn_cast<Instruction>(NOp1))
2868     I->copyIRFlags(Op1, true);
2869 
2870   for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
2871     Builder.SetInsertPoint(Shuffles[S]);
2872     Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
2873     replaceValue(*Shuffles[S], *NSV);
2874   }
2875 
2876   Worklist.pushValue(NSV0A);
2877   Worklist.pushValue(NSV0B);
2878   Worklist.pushValue(NSV1A);
2879   Worklist.pushValue(NSV1B);
2880   for (auto *S : Shuffles)
2881     Worklist.add(S);
2882   return true;
2883 }
2884 
2885 /// Check if the instruction depends on a ZExt and whether that ZExt can be
2886 /// moved after the instruction. Move the ZExt if it is profitable. For example:
2887 ///     logic(zext(x),y) -> zext(logic(x,trunc(y)))
2888 ///     lshr(zext(x),y)  -> zext(lshr(x,trunc(y)))
2889 /// The cost model calculations take into account whether zext(x) has other
2890 /// users and whether it can be propagated through them too.
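/// An illustrative (hypothetical) IR sketch of the logic case, assuming the
/// narrower operations are cheaper on the target:
///   %zx = zext <8 x i8> %x to <8 x i32>
///   %r  = and <8 x i32> %zx, %y
/// becomes:
///   %t  = trunc <8 x i32> %y to <8 x i8>
///   %a  = and <8 x i8> %x, %t
///   %r  = zext <8 x i8> %a to <8 x i32>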
2891 bool VectorCombine::shrinkType(Instruction &I) {
2892   Value *ZExted, *OtherOperand;
2893   if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
2894                                   m_Value(OtherOperand))) &&
2895       !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
2896     return false;
2897 
2898   Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
2899 
2900   auto *BigTy = cast<FixedVectorType>(I.getType());
2901   auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
2902   unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
2903 
2904   if (I.getOpcode() == Instruction::LShr) {
2905     // Check that the shift amount is less than the number of bits in the
2906     // smaller type. Otherwise, the smaller lshr will return a poison value.
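    // For example (hypothetical types), if the source was zext'd from i8 and
    // the shift amount may be 9, a narrowed i8 lshr would shift by >= 8 bits
    // and produce poison, so we bail out.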
2907     KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
2908     if (ShAmtKB.getMaxValue().uge(BW))
2909       return false;
2910   } else {
2911     // Check that the expression overall uses at most the same number of bits as
2912     // ZExted
2913     KnownBits KB = computeKnownBits(&I, *DL);
2914     if (KB.countMaxActiveBits() > BW)
2915       return false;
2916   }
2917 
2918   // Calculate the costs of leaving the current IR as it is versus moving the
2919   // ZExt later, along with adding truncates if needed.
2920   InstructionCost ZExtCost = TTI.getCastInstrCost(
2921       Instruction::ZExt, BigTy, SmallTy,
2922       TargetTransformInfo::CastContextHint::None, CostKind);
2923   InstructionCost CurrentCost = ZExtCost;
2924   InstructionCost ShrinkCost = 0;
2925 
2926   // Calculate total cost and check that we can propagate through all ZExt users
2927   for (User *U : ZExtOperand->users()) {
2928     auto *UI = cast<Instruction>(U);
2929     if (UI == &I) {
2930       CurrentCost +=
2931           TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2932       ShrinkCost +=
2933           TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2934       ShrinkCost += ZExtCost;
2935       continue;
2936     }
2937 
2938     if (!Instruction::isBinaryOp(UI->getOpcode()))
2939       return false;
2940 
2941     // Check if we can propagate ZExt through its other users
2942     KnownBits KB = computeKnownBits(UI, *DL);
2943     if (KB.countMaxActiveBits() > BW)
2944       return false;
2945 
2946     CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2947     ShrinkCost +=
2948         TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2949     ShrinkCost += ZExtCost;
2950   }
2951 
2952   // If the other instruction operand is not a constant, we'll need to
2953   // generate a truncate instruction, so we have to adjust the cost.
2954   if (!isa<Constant>(OtherOperand))
2955     ShrinkCost += TTI.getCastInstrCost(
2956         Instruction::Trunc, SmallTy, BigTy,
2957         TargetTransformInfo::CastContextHint::None, CostKind);
2958 
2959   // If the cost of shrinking types and the cost of leaving the IR as it is are
2960   // the same, we'll lean towards modifying the IR because shrinking opens
2961   // opportunities for other shrinking optimizations.
2962   if (ShrinkCost > CurrentCost)
2963     return false;
2964 
2965   Builder.SetInsertPoint(&I);
2966   Value *Op0 = ZExted;
2967   Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
2968   // Keep the order of operands the same
2969   if (I.getOperand(0) == OtherOperand)
2970     std::swap(Op0, Op1);
2971   Value *NewBinOp =
2972       Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
2973   cast<Instruction>(NewBinOp)->copyIRFlags(&I);
2974   cast<Instruction>(NewBinOp)->copyMetadata(I);
2975   Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
2976   replaceValue(I, *NewZExtr);
2977   return true;
2978 }
2979 
2980 /// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
2981 /// shuffle (DstVec, SrcVec, Mask)
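/// An illustrative (hypothetical) example with <4 x i32> vectors, ExtIdx == 2
/// and InsIdx == 0:
///   %e = extractelement <4 x i32> %src, i64 2
///   %i = insertelement <4 x i32> %dst, i32 %e, i64 0
/// becomes:
///   %i = shufflevector <4 x i32> %dst, <4 x i32> %src,
///                      <4 x i32> <i32 6, i32 1, i32 2, i32 3>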
2982 bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
2983   Value *DstVec, *SrcVec;
2984   uint64_t ExtIdx, InsIdx;
2985   if (!match(&I,
2986              m_InsertElt(m_Value(DstVec),
2987                          m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
2988                          m_ConstantInt(InsIdx))))
2989     return false;
2990 
2991   auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
2992   if (!VecTy || SrcVec->getType() != VecTy)
2993     return false;
2994 
2995   unsigned NumElts = VecTy->getNumElements();
2996   if (ExtIdx >= NumElts || InsIdx >= NumElts)
2997     return false;
2998 
2999   SmallVector<int> Mask(NumElts, 0);
3000   std::iota(Mask.begin(), Mask.end(), 0);
3001   Mask[InsIdx] = ExtIdx + NumElts;
3002   // Cost
3003   auto *Ins = cast<InsertElementInst>(&I);
3004   auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
3005 
3006   InstructionCost OldCost =
3007       TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx) +
3008       TTI.getVectorInstrCost(*Ins, VecTy, CostKind, InsIdx);
3009 
3010   InstructionCost NewCost = TTI.getShuffleCost(
3011       TargetTransformInfo::SK_PermuteTwoSrc, VecTy, Mask, CostKind);
3012   if (!Ext->hasOneUse())
3013     NewCost += TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx);
3014 
3015   LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair: " << I
3016                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
3017                     << "\n");
3018 
3019   if (OldCost < NewCost)
3020     return false;
3021 
3022   // Canonicalize undef param to RHS to help further folds.
3023   if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
3024     ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
3025     std::swap(DstVec, SrcVec);
3026   }
3027 
3028   Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
3029   replaceValue(I, *Shuf);
3030 
3031   return true;
3032 }
3033 
3034 /// This is the entry point for all transforms. Pass manager differences are
3035 /// handled in the callers of this function.
3036 bool VectorCombine::run() {
3037   if (DisableVectorCombine)
3038     return false;
3039 
3040   // Don't attempt vectorization if the target does not support vectors.
3041   if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
3042     return false;
3043 
3044   LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
3045 
3046   bool MadeChange = false;
3047   auto FoldInst = [this, &MadeChange](Instruction &I) {
3048     Builder.SetInsertPoint(&I);
3049     bool IsVectorType = isa<VectorType>(I.getType());
3050     bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
3051     auto Opcode = I.getOpcode();
3052 
3053     LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
3054 
3055     // These folds should be beneficial regardless of when this pass is run
3056     // in the optimization pipeline.
3057     // The type checking is for run-time efficiency. We can avoid wasting time
3058     // dispatching to folding functions if there's no chance of matching.
3059     if (IsFixedVectorType) {
3060       switch (Opcode) {
3061       case Instruction::InsertElement:
3062         MadeChange |= vectorizeLoadInsert(I);
3063         break;
3064       case Instruction::ShuffleVector:
3065         MadeChange |= widenSubvectorLoad(I);
3066         break;
3067       default:
3068         break;
3069       }
3070     }
3071 
3072     // These transforms work with scalable and fixed vectors.
3073     // TODO: Identify and allow other scalable transforms
3074     if (IsVectorType) {
3075       MadeChange |= scalarizeBinopOrCmp(I);
3076       MadeChange |= scalarizeLoadExtract(I);
3077       MadeChange |= scalarizeVPIntrinsic(I);
3078     }
3079 
3080     if (Opcode == Instruction::Store)
3081       MadeChange |= foldSingleElementStore(I);
3082 
3083     // If this is an early pipeline invocation of this pass, we are done.
3084     if (TryEarlyFoldsOnly)
3085       return;
3086 
3087     // Otherwise, try folds that improve codegen but may interfere with
3088     // early IR canonicalizations.
3089     // The type checking is for run-time efficiency. We can avoid wasting time
3090     // dispatching to folding functions if there's no chance of matching.
3091     if (IsFixedVectorType) {
3092       switch (Opcode) {
3093       case Instruction::InsertElement:
3094         MadeChange |= foldInsExtFNeg(I);
3095         MadeChange |= foldInsExtVectorToShuffle(I);
3096         break;
3097       case Instruction::ShuffleVector:
3098         MadeChange |= foldPermuteOfBinops(I);
3099         MadeChange |= foldShuffleOfBinops(I);
3100         MadeChange |= foldShuffleOfCastops(I);
3101         MadeChange |= foldShuffleOfShuffles(I);
3102         MadeChange |= foldShuffleOfIntrinsics(I);
3103         MadeChange |= foldSelectShuffle(I);
3104         MadeChange |= foldShuffleToIdentity(I);
3105         break;
3106       case Instruction::BitCast:
3107         MadeChange |= foldBitcastShuffle(I);
3108         break;
3109       default:
3110         MadeChange |= shrinkType(I);
3111         break;
3112       }
3113     } else {
3114       switch (Opcode) {
3115       case Instruction::Call:
3116         MadeChange |= foldShuffleFromReductions(I);
3117         MadeChange |= foldCastFromReductions(I);
3118         break;
3119       case Instruction::ICmp:
3120       case Instruction::FCmp:
3121         MadeChange |= foldExtractExtract(I);
3122         break;
3123       case Instruction::Or:
3124         MadeChange |= foldConcatOfBoolMasks(I);
3125         [[fallthrough]];
3126       default:
3127         if (Instruction::isBinaryOp(Opcode)) {
3128           MadeChange |= foldExtractExtract(I);
3129           MadeChange |= foldExtractedCmps(I);
3130         }
3131         break;
3132       }
3133     }
3134   };
3135 
3136   for (BasicBlock &BB : F) {
3137     // Ignore unreachable basic blocks.
3138     if (!DT.isReachableFromEntry(&BB))
3139       continue;
3140     // Use early increment range so that we can erase instructions in loop.
3141     for (Instruction &I : make_early_inc_range(BB)) {
3142       if (I.isDebugOrPseudoInst())
3143         continue;
3144       FoldInst(I);
3145     }
3146   }
3147 
3148   while (!Worklist.isEmpty()) {
3149     Instruction *I = Worklist.removeOne();
3150     if (!I)
3151       continue;
3152 
3153     if (isInstructionTriviallyDead(I)) {
3154       eraseInstruction(*I);
3155       continue;
3156     }
3157 
3158     FoldInst(*I);
3159   }
3160 
3161   return MadeChange;
3162 }
3163 
3164 PreservedAnalyses VectorCombinePass::run(Function &F,
3165                                          FunctionAnalysisManager &FAM) {
3166   auto &AC = FAM.getResult<AssumptionAnalysis>(F);
3167   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
3168   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
3169   AAResults &AA = FAM.getResult<AAManager>(F);
3170   const DataLayout *DL = &F.getDataLayout();
3171   VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
3172                          TryEarlyFoldsOnly);
3173   if (!Combiner.run())
3174     return PreservedAnalyses::all();
3175   PreservedAnalyses PA;
3176   PA.preserveSet<CFGAnalyses>();
3177   return PA;
3178 }
3179