xref: /llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp (revision d5a96eb125eaf661f7f9aad4cd184973fe750528)
1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes scalar/vector interactions using target cost models. The
10 // transforms implemented here may not fit in traditional loop-based or SLP
11 // vectorization passes.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/Vectorize/VectorCombine.h"
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/Analysis/AssumptionCache.h"
21 #include "llvm/Analysis/BasicAliasAnalysis.h"
22 #include "llvm/Analysis/GlobalsModRef.h"
23 #include "llvm/Analysis/Loads.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Analysis/VectorUtils.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/PatternMatch.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Transforms/Utils/Local.h"
33 #include "llvm/Transforms/Utils/LoopUtils.h"
34 #include <numeric>
35 #include <queue>
36 
37 #define DEBUG_TYPE "vector-combine"
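// Note: InstructionWorklist.h uses LLVM_DEBUG (and thus DEBUG_TYPE) in its
// inline methods, so it is included after DEBUG_TYPE is defined above.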
38 #include "llvm/Transforms/Utils/InstructionWorklist.h"
39 
40 using namespace llvm;
41 using namespace llvm::PatternMatch;
42 
43 STATISTIC(NumVecLoad, "Number of vector loads formed");
44 STATISTIC(NumVecCmp, "Number of vector compares formed");
45 STATISTIC(NumVecBO, "Number of vector binops formed");
46 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
47 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
48 STATISTIC(NumScalarBO, "Number of scalar binops formed");
49 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
50 
51 static cl::opt<bool> DisableVectorCombine(
52     "disable-vector-combine", cl::init(false), cl::Hidden,
53     cl::desc("Disable all vector combine transforms"));
54 
55 static cl::opt<bool> DisableBinopExtractShuffle(
56     "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
57     cl::desc("Disable binop extract to shuffle transforms"));
58 
59 static cl::opt<unsigned> MaxInstrsToScan(
60     "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
61     cl::desc("Max number of instructions to scan for vector combining."));
62 
63 static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
64 
65 namespace {
66 class VectorCombine {
67 public:
68   VectorCombine(Function &F, const TargetTransformInfo &TTI,
69                 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
70                 const DataLayout *DL, TTI::TargetCostKind CostKind,
71                 bool TryEarlyFoldsOnly)
72       : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), DL(DL),
73         CostKind(CostKind), TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
74 
75   bool run();
76 
77 private:
78   Function &F;
79   IRBuilder<> Builder;
80   const TargetTransformInfo &TTI;
81   const DominatorTree &DT;
82   AAResults &AA;
83   AssumptionCache &AC;
84   const DataLayout *DL;
85   TTI::TargetCostKind CostKind;
86 
87   /// If true, only perform beneficial early IR transforms. Do not introduce new
88   /// vector operations.
89   bool TryEarlyFoldsOnly;
90 
91   InstructionWorklist Worklist;
92 
93   // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
94   //       parameter. That should be updated to specific sub-classes because the
95   //       run loop was changed to dispatch on opcode.
96   bool vectorizeLoadInsert(Instruction &I);
97   bool widenSubvectorLoad(Instruction &I);
98   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
99                                         ExtractElementInst *Ext1,
100                                         unsigned PreferredExtractIndex) const;
101   bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
102                              const Instruction &I,
103                              ExtractElementInst *&ConvertToShuffle,
104                              unsigned PreferredExtractIndex);
105   void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
106                      Instruction &I);
107   void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
108                        Instruction &I);
109   bool foldExtractExtract(Instruction &I);
110   bool foldInsExtFNeg(Instruction &I);
111   bool foldInsExtVectorToShuffle(Instruction &I);
112   bool foldBitcastShuffle(Instruction &I);
113   bool scalarizeBinopOrCmp(Instruction &I);
114   bool scalarizeVPIntrinsic(Instruction &I);
115   bool foldExtractedCmps(Instruction &I);
116   bool foldSingleElementStore(Instruction &I);
117   bool scalarizeLoadExtract(Instruction &I);
118   bool foldConcatOfBoolMasks(Instruction &I);
119   bool foldPermuteOfBinops(Instruction &I);
120   bool foldShuffleOfBinops(Instruction &I);
121   bool foldShuffleOfCastops(Instruction &I);
122   bool foldShuffleOfShuffles(Instruction &I);
123   bool foldShuffleOfIntrinsics(Instruction &I);
124   bool foldShuffleToIdentity(Instruction &I);
125   bool foldShuffleFromReductions(Instruction &I);
126   bool foldCastFromReductions(Instruction &I);
127   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
128   bool shrinkType(Instruction &I);
129 
130   void replaceValue(Value &Old, Value &New) {
131     Old.replaceAllUsesWith(&New);
132     if (auto *NewI = dyn_cast<Instruction>(&New)) {
133       New.takeName(&Old);
134       Worklist.pushUsersToWorkList(*NewI);
135       Worklist.pushValue(NewI);
136     }
137     Worklist.pushValue(&Old);
138   }
139 
140   void eraseInstruction(Instruction &I) {
141     LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
142     for (Value *Op : I.operands())
143       Worklist.pushValue(Op);
144     Worklist.remove(&I);
145     I.eraseFromParent();
146   }
147 };
148 } // namespace
149 
150 /// Return the source operand of a potentially bitcasted value. If there is no
151 /// bitcast, return the input value itself.
152 static Value *peekThroughBitcasts(Value *V) {
153   while (auto *BitCast = dyn_cast<BitCastInst>(V))
154     V = BitCast->getOperand(0);
155   return V;
156 }
157 
158 static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
159   // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
160   // The widened load may load data from dirty regions or create data races
161   // non-existent in the source.
162   if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
163       Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
164       mustSuppressSpeculation(*Load))
165     return false;
166 
167   // We are potentially transforming byte-sized (8-bit) memory accesses, so make
168   // sure we have all of our type-based constraints in place for this target.
169   Type *ScalarTy = Load->getType()->getScalarType();
170   uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
171   unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
172   if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
173       ScalarSize % 8 != 0)
174     return false;
175 
176   return true;
177 }
178 
179 bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
180   // Match insert into fixed vector of scalar value.
181   // TODO: Handle non-zero insert index.
182   Value *Scalar;
183   if (!match(&I,
184              m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt())))
185     return false;
186 
187   // Optionally match an extract from another vector.
188   Value *X;
189   bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
190   if (!HasExtract)
191     X = Scalar;
192 
193   auto *Load = dyn_cast<LoadInst>(X);
194   if (!canWidenLoad(Load, TTI))
195     return false;
196 
197   Type *ScalarTy = Scalar->getType();
198   uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
199   unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
200 
201   // Check safety of replacing the scalar load with a larger vector load.
202   // We use minimal alignment (maximum flexibility) because we only care about
203   // the dereferenceable region. When calculating cost and creating a new op,
204   // we may use a larger value based on alignment attributes.
205   Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
206   assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
207 
208   unsigned MinVecNumElts = MinVectorSize / ScalarSize;
209   auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
210   unsigned OffsetEltIndex = 0;
211   Align Alignment = Load->getAlign();
212   if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
213                                    &DT)) {
214     // It is not safe to load directly from the pointer, but we can still peek
215     // through gep offsets and check if it is safe to load from a base address
216     // with updated alignment. If so, we can shuffle the element(s) into place
217     // after loading.
218     unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
219     APInt Offset(OffsetBitWidth, 0);
220     SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);
221 
222     // We want to shuffle the result down from a high element of a vector, so
223     // the offset must be positive.
224     if (Offset.isNegative())
225       return false;
226 
227     // The offset must be a multiple of the scalar element size so that the
228     // target element can be shuffled cleanly into place.
229     uint64_t ScalarSizeInBytes = ScalarSize / 8;
230     if (Offset.urem(ScalarSizeInBytes) != 0)
231       return false;
232 
233     // If we load MinVecNumElts, will our target element still be loaded?
234     OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
235     if (OffsetEltIndex >= MinVecNumElts)
236       return false;
237 
238     if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
239                                      &DT))
240       return false;
241 
242     // Update alignment with offset value. Note that the offset could be negated
243     // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
244     // negation does not change the result of the alignment calculation.
245     Alignment = commonAlignment(Alignment, Offset.getZExtValue());
246   }
247 
248   // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
249   // Use the greater of the alignment on the load or its source pointer.
250   Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
251   Type *LoadTy = Load->getType();
252   unsigned AS = Load->getPointerAddressSpace();
253   InstructionCost OldCost =
254       TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
255   APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
256   OldCost +=
257       TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
258                                    /* Insert */ true, HasExtract, CostKind);
259 
260   // New pattern: load VecPtr
261   InstructionCost NewCost =
262       TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
263   // Optionally, we are shuffling the loaded vector element(s) into place.
264   // For the mask, set everything but element 0 to poison to prevent poison
265   // from propagating from the extra loaded memory. This will also optionally
266   // shrink/grow the vector from the loaded size to the output size.
267   // We assume this operation has no cost in codegen if there was no offset.
268   // Note that we could use freeze to avoid poison problems, but then we might
269   // still need a shuffle to change the vector size.
270   auto *Ty = cast<FixedVectorType>(I.getType());
271   unsigned OutputNumElts = Ty->getNumElements();
272   SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
273   assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
274   Mask[0] = OffsetEltIndex;
275   if (OffsetEltIndex)
276     NewCost +=
277         TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask, CostKind);
278 
279   // We can aggressively convert to the vector form because the backend can
280   // invert this transform if it does not result in a performance win.
281   if (OldCost < NewCost || !NewCost.isValid())
282     return false;
283 
284   // It is safe and potentially profitable to load a vector directly:
285   // inselt undef, load Scalar, 0 --> load VecPtr
286   IRBuilder<> Builder(Load);
287   Value *CastedPtr =
288       Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
289   Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
290   VecLd = Builder.CreateShuffleVector(VecLd, Mask);
291 
292   replaceValue(I, *VecLd);
293   ++NumVecLoad;
294   return true;
295 }
296 
297 /// If we are loading a vector and then inserting it into a larger vector with
298 /// undefined elements, try to load the larger vector and eliminate the insert.
299 /// This removes a shuffle in IR and may allow combining of other loaded values.
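/// For example, assuming the wider load is known dereferenceable and the cost
/// model agrees:
///   %s = load <2 x float>, ptr %p
///   %i = shufflevector <2 x float> %s, <2 x float> poison,
///                      <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
/// becomes a single wider load:
///   %i = load <4 x float>, ptr %p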
300 bool VectorCombine::widenSubvectorLoad(Instruction &I) {
301   // Match subvector insert of fixed vector.
302   auto *Shuf = cast<ShuffleVectorInst>(&I);
303   if (!Shuf->isIdentityWithPadding())
304     return false;
305 
306   // Allow a non-canonical shuffle mask that is choosing elements from op1.
307   unsigned NumOpElts =
308       cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
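  // The any_of below returns true iff some mask element selects from operand 1,
  // so OpIndex implicitly becomes 0 or 1: the operand supplying the subvector.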
309   unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
310     return M >= (int)(NumOpElts);
311   });
312 
313   auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
314   if (!canWidenLoad(Load, TTI))
315     return false;
316 
317   // We use minimal alignment (maximum flexibility) because we only care about
318   // the dereferenceable region. When calculating cost and creating a new op,
319   // we may use a larger value based on alignment attributes.
320   auto *Ty = cast<FixedVectorType>(I.getType());
321   Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
322   assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
323   Align Alignment = Load->getAlign();
324   if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT))
325     return false;
326 
327   Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
328   Type *LoadTy = Load->getType();
329   unsigned AS = Load->getPointerAddressSpace();
330 
331   // Original pattern: insert_subvector (load PtrOp)
332   // This conservatively assumes that the cost of a subvector insert into an
333   // undef value is 0. We could add that cost if the cost model accurately
334   // reflects the real cost of that operation.
335   InstructionCost OldCost =
336       TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
337 
338   // New pattern: load PtrOp
339   InstructionCost NewCost =
340       TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);
341 
342   // We can aggressively convert to the vector form because the backend can
343   // invert this transform if it does not result in a performance win.
344   if (OldCost < NewCost || !NewCost.isValid())
345     return false;
346 
347   IRBuilder<> Builder(Load);
348   Value *CastedPtr =
349       Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
350   Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
351   replaceValue(I, *VecLd);
352   ++NumVecLoad;
353   return true;
354 }
355 
356 /// Determine which, if any, of the inputs should be replaced by a shuffle
357 /// followed by extract from a different index.
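/// For example, if Ext0 reads lane 3 and Ext1 reads lane 0, the more expensive
/// of the two is returned; its source is later shuffled so both values can be
/// extracted from the same lane.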
358 ExtractElementInst *VectorCombine::getShuffleExtract(
359     ExtractElementInst *Ext0, ExtractElementInst *Ext1,
360     unsigned PreferredExtractIndex = InvalidIndex) const {
361   auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
362   auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
363   assert(Index0C && Index1C && "Expected constant extract indexes");
364 
365   unsigned Index0 = Index0C->getZExtValue();
366   unsigned Index1 = Index1C->getZExtValue();
367 
368   // If the extract indexes are identical, no shuffle is needed.
369   if (Index0 == Index1)
370     return nullptr;
371 
372   Type *VecTy = Ext0->getVectorOperand()->getType();
373   assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
374   InstructionCost Cost0 =
375       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
376   InstructionCost Cost1 =
377       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
378 
379   // If both costs are invalid, no shuffle is needed.
380   if (!Cost0.isValid() && !Cost1.isValid())
381     return nullptr;
382 
383   // We are extracting from 2 different indexes, so one operand must be shuffled
384   // before performing a vector operation and/or extract. The more expensive
385   // extract will be replaced by a shuffle.
386   if (Cost0 > Cost1)
387     return Ext0;
388   if (Cost1 > Cost0)
389     return Ext1;
390 
391   // If the costs are equal and there is a preferred extract index, shuffle the
392   // opposite operand.
393   if (PreferredExtractIndex == Index0)
394     return Ext1;
395   if (PreferredExtractIndex == Index1)
396     return Ext0;
397 
398   // Otherwise, replace the extract with the higher index.
399   return Index0 > Index1 ? Ext0 : Ext1;
400 }
401 
402 /// Compare the relative costs of 2 extracts followed by scalar operation vs.
403 /// vector operation(s) followed by extract. Return true if the existing
404 /// instructions are cheaper than a vector alternative. Otherwise, return false
405 /// and, if one of the extracts should be transformed to a shufflevector, set
406 /// \p ConvertToShuffle to that extract instruction.
407 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
408                                           ExtractElementInst *Ext1,
409                                           const Instruction &I,
410                                           ExtractElementInst *&ConvertToShuffle,
411                                           unsigned PreferredExtractIndex) {
412   auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
413   auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
414   assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
415 
416   unsigned Opcode = I.getOpcode();
417   Value *Ext0Src = Ext0->getVectorOperand();
418   Value *Ext1Src = Ext1->getVectorOperand();
419   Type *ScalarTy = Ext0->getType();
420   auto *VecTy = cast<VectorType>(Ext0Src->getType());
421   InstructionCost ScalarOpCost, VectorOpCost;
422 
423   // Get cost estimates for scalar and vector versions of the operation.
424   bool IsBinOp = Instruction::isBinaryOp(Opcode);
425   if (IsBinOp) {
426     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
427     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
428   } else {
429     assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
430            "Expected a compare");
431     CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
432     ScalarOpCost = TTI.getCmpSelInstrCost(
433         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
434     VectorOpCost = TTI.getCmpSelInstrCost(
435         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
436   }
437 
438   // Get cost estimates for the extract elements. These costs will factor into
439   // both sequences.
440   unsigned Ext0Index = Ext0IndexC->getZExtValue();
441   unsigned Ext1Index = Ext1IndexC->getZExtValue();
442 
443   InstructionCost Extract0Cost =
444       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
445   InstructionCost Extract1Cost =
446       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);
447 
448   // A more expensive extract will always be replaced by a splat shuffle.
449   // For example, if Ext0 is more expensive:
450   // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
451   // extelt (opcode (splat V0, Ext0), V1), Ext1
452   // TODO: Evaluate whether that always results in lowest cost. Alternatively,
453   //       check the cost of creating a broadcast shuffle and shuffling both
454   //       operands to element 0.
455   unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
456   unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
457   InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
458 
459   // Extra uses of the extracts mean that we include those costs in the
460   // vector total because those instructions will not be eliminated.
461   InstructionCost OldCost, NewCost;
462   if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
463     // Handle a special case. If the 2 extracts are identical, adjust the
464     // formulas to account for that. The extra use charge allows for either the
465     // CSE'd pattern or an unoptimized form with identical values:
466     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
467     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
468                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
469     OldCost = CheapExtractCost + ScalarOpCost;
470     NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
471   } else {
472     // Handle the general case. Each extract is actually a different value:
473     // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
474     OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
475     NewCost = VectorOpCost + CheapExtractCost +
476               !Ext0->hasOneUse() * Extract0Cost +
477               !Ext1->hasOneUse() * Extract1Cost;
478   }
479 
480   ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
481   if (ConvertToShuffle) {
482     if (IsBinOp && DisableBinopExtractShuffle)
483       return true;
484 
485     // If we are extracting from 2 different indexes, then one operand must be
486     // shuffled before performing the vector operation. The shuffle mask is
487     // poison except for 1 lane that is being translated to the remaining
488     // extraction lane. Therefore, it is a splat shuffle. Ex:
489     // ShufMask = { poison, poison, 0, poison }
490     // TODO: The cost model has an option for a "broadcast" shuffle
491     //       (splat-from-element-0), but no option for a more general splat.
492     if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
493       SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
494                                    PoisonMaskElem);
495       ShuffleMask[BestInsIndex] = BestExtIndex;
496       NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
497                                     VecTy, ShuffleMask, CostKind, 0, nullptr,
498                                     {ConvertToShuffle});
499     } else {
500       NewCost +=
501           TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy,
502                              {}, CostKind, 0, nullptr, {ConvertToShuffle});
503     }
504   }
505 
506   // Aggressively form a vector op if the cost is equal because the transform
507   // may enable further optimization.
508   // Codegen can reverse this transform (scalarize) if it was not profitable.
509   return OldCost < NewCost;
510 }
511 
512 /// Create a shuffle that translates (shifts) 1 element from the input vector
513 /// to a new element location.
514 static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
515                                  unsigned NewIndex, IRBuilder<> &Builder) {
516   // The shuffle mask is poison except for 1 lane that is being translated
517   // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
518   // ShufMask = { 2, poison, poison, poison }
519   auto *VecTy = cast<FixedVectorType>(Vec->getType());
520   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
521   ShufMask[NewIndex] = OldIndex;
522   return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
523 }
524 
525 /// Given an extract element instruction with constant index operand, shuffle
526 /// the source vector (shift the scalar element) to a NewIndex for extraction.
527 /// Return null if the input can be constant folded, so that we are not creating
528 /// unnecessary instructions.
529 static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
530                                             unsigned NewIndex,
531                                             IRBuilder<> &Builder) {
532   // Shufflevectors can only be created for fixed-width vectors.
533   Value *X = ExtElt->getVectorOperand();
534   if (!isa<FixedVectorType>(X->getType()))
535     return nullptr;
536 
537   // If the extract can be constant-folded, this code has not been simplified;
538   // defer to other passes to handle that.
539   Value *C = ExtElt->getIndexOperand();
540   assert(isa<ConstantInt>(C) && "Expected a constant index operand");
541   if (isa<Constant>(X))
542     return nullptr;
543 
544   Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
545                                    NewIndex, Builder);
546   return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
547 }
548 
549 /// Try to reduce extract element costs by converting scalar compares to vector
550 /// compares followed by extract.
551 /// cmp (ext0 V0, C), (ext1 V1, C)
552 void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
553                                   ExtractElementInst *Ext1, Instruction &I) {
554   assert(isa<CmpInst>(&I) && "Expected a compare");
555   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
556              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
557          "Expected matching constant extract indexes");
558 
559   // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
560   ++NumVecCmp;
561   CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
562   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
563   Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
564   Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
565   replaceValue(I, *NewExt);
566 }
567 
568 /// Try to reduce extract element costs by converting scalar binops to vector
569 /// binops followed by extract.
570 /// bo (ext0 V0, C), (ext1 V1, C)
571 void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
572                                     ExtractElementInst *Ext1, Instruction &I) {
573   assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
574   assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
575              cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
576          "Expected matching constant extract indexes");
577 
578   // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
579   ++NumVecBO;
580   Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
581   Value *VecBO =
582       Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
583 
584   // All IR flags are safe to back-propagate because any potential poison
585   // created in unused vector elements is discarded by the extract.
586   if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
587     VecBOInst->copyIRFlags(&I);
588 
589   Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
590   replaceValue(I, *NewExt);
591 }
592 
593 /// Match an instruction with extracted vector operands.
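/// For example, when the cost model favors the vector form:
///   add (extractelt <4 x i32> %v0, 3), (extractelt <4 x i32> %v1, 3)
///     --> extractelt (add <4 x i32> %v0, %v1), 3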
594 bool VectorCombine::foldExtractExtract(Instruction &I) {
595   // It is not safe to transform things like div, urem, etc. because we may
596   // create undefined behavior when executing those on unknown vector elements.
597   if (!isSafeToSpeculativelyExecute(&I))
598     return false;
599 
600   Instruction *I0, *I1;
601   CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
602   if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
603       !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
604     return false;
605 
606   Value *V0, *V1;
607   uint64_t C0, C1;
608   if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
609       !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
610       V0->getType() != V1->getType())
611     return false;
612 
613   // If the scalar value 'I' is going to be re-inserted into a vector, then try
614   // to create an extract to that same element. The extract/insert can be
615   // reduced to a "select shuffle".
616   // TODO: If we add a larger pattern match that starts from an insert, this
617   //       probably becomes unnecessary.
618   auto *Ext0 = cast<ExtractElementInst>(I0);
619   auto *Ext1 = cast<ExtractElementInst>(I1);
620   uint64_t InsertIndex = InvalidIndex;
621   if (I.hasOneUse())
622     match(I.user_back(),
623           m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
624 
625   ExtractElementInst *ExtractToChange;
626   if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
627     return false;
628 
629   if (ExtractToChange) {
630     unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
631     ExtractElementInst *NewExtract =
632         translateExtract(ExtractToChange, CheapExtractIdx, Builder);
633     if (!NewExtract)
634       return false;
635     if (ExtractToChange == Ext0)
636       Ext0 = NewExtract;
637     else
638       Ext1 = NewExtract;
639   }
640 
641   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
642     foldExtExtCmp(Ext0, Ext1, I);
643   else
644     foldExtExtBinop(Ext0, Ext1, I);
645 
646   Worklist.push(Ext0);
647   Worklist.push(Ext1);
648   return true;
649 }
650 
651 /// Try to replace an extract + scalar fneg + insert with a vector fneg +
652 /// shuffle.
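/// For example, with a 4-element vector and Index == 1:
///   insertelt DestVec, (fneg (extractelt SrcVec, 1)), 1
///     --> shuffle DestVec, (fneg SrcVec), <0, 5, 2, 3>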
653 bool VectorCombine::foldInsExtFNeg(Instruction &I) {
654   // Match an insert (op (extract)) pattern.
655   Value *DestVec;
656   uint64_t Index;
657   Instruction *FNeg;
658   if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
659                              m_ConstantInt(Index))))
660     return false;
661 
662   // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
663   Value *SrcVec;
664   Instruction *Extract;
665   if (!match(FNeg, m_FNeg(m_CombineAnd(
666                        m_Instruction(Extract),
667                        m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
668     return false;
669 
670   auto *VecTy = cast<FixedVectorType>(I.getType());
671   auto *ScalarTy = VecTy->getScalarType();
672   auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
673   if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType())
674     return false;
675 
676   // Ignore bogus insert/extract index.
677   unsigned NumElts = VecTy->getNumElements();
678   if (Index >= NumElts)
679     return false;
680 
681   // We are inserting the negated element into the same lane that we extracted
682   // from. This is equivalent to a select-shuffle that chooses all but the
683   // negated element from the destination vector.
684   SmallVector<int> Mask(NumElts);
685   std::iota(Mask.begin(), Mask.end(), 0);
686   Mask[Index] = Index + NumElts;
687   InstructionCost OldCost =
688       TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) +
689       TTI.getVectorInstrCost(I, VecTy, CostKind, Index);
690 
691   // If the extract has one use, it will be eliminated, so count it in the
692   // original cost. If it has more than one use, ignore the cost because it will
693   // be the same before/after.
694   if (Extract->hasOneUse())
695     OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);
696 
697   InstructionCost NewCost =
698       TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) +
699       TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, CostKind);
700 
701   bool NeedLenChg = SrcVecTy->getNumElements() != NumElts;
702   // If the lengths of the two vectors are not equal, we need an extra
703   // length-changing shuffle of the source vector. Add its cost.
704   SmallVector<int> SrcMask;
705   if (NeedLenChg) {
706     SrcMask.assign(NumElts, PoisonMaskElem);
707     SrcMask[Index] = Index;
708     NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
709                                   SrcVecTy, SrcMask, CostKind);
710   }
711 
712   if (NewCost > OldCost)
713     return false;
714 
715   Value *NewShuf;
716   // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index
717   Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
718   if (NeedLenChg) {
719     // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
720     Value *LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
721     NewShuf = Builder.CreateShuffleVector(DestVec, LenChgShuf, Mask);
722   } else {
723     // shuffle DestVec, (fneg SrcVec), Mask
724     NewShuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
725   }
726 
727   replaceValue(I, *NewShuf);
728   return true;
729 }
730 
731 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
732 /// destination type followed by shuffle. This can enable further transforms by
733 /// moving bitcasts or shuffles together.
734 bool VectorCombine::foldBitcastShuffle(Instruction &I) {
735   Value *V0, *V1;
736   ArrayRef<int> Mask;
737   if (!match(&I, m_BitCast(m_OneUse(
738                      m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
739     return false;
740 
741   // 1) Do not fold a bitcast of a shuffle for scalable types. First, the
742   // shuffle cost for scalable types is unknown; second, we cannot tell whether
743   // the narrowed shuffle mask for a scalable type would be a splat.
744   // 2) Disallow non-vector casts.
745   // TODO: We could allow any shuffle.
746   auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
747   auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
748   if (!DestTy || !SrcTy)
749     return false;
750 
751   unsigned DestEltSize = DestTy->getScalarSizeInBits();
752   unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
753   if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
754     return false;
755 
756   bool IsUnary = isa<UndefValue>(V1);
757 
758   // For binary shuffles, only fold bitcast(shuffle(X,Y))
759   // if it won't increase the number of bitcasts.
760   if (!IsUnary) {
761     auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType());
762     auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType());
763     if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
764         !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
765       return false;
766   }
767 
768   SmallVector<int, 16> NewMask;
769   if (DestEltSize <= SrcEltSize) {
770     // The bitcast is from wide to narrow/equal elements. The shuffle mask can
771     // always be expanded to the equivalent form choosing narrower elements.
772     assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
773     unsigned ScaleFactor = SrcEltSize / DestEltSize;
774     narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
775   } else {
776     // The bitcast is from narrow elements to wide elements. The shuffle mask
777     // must choose consecutive elements to allow casting first.
778     assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
779     unsigned ScaleFactor = DestEltSize / SrcEltSize;
780     if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
781       return false;
782   }
783 
784   // Bitcast the shuffle source - keep its original width but use the
785   // destination scalar type.
786   unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
787   auto *NewShuffleTy =
788       FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
789   auto *OldShuffleTy =
790       FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
791   unsigned NumOps = IsUnary ? 1 : 2;
792 
793   // The new shuffle must not cost more than the old shuffle.
794   TargetTransformInfo::ShuffleKind SK =
795       IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
796               : TargetTransformInfo::SK_PermuteTwoSrc;
797 
798   InstructionCost NewCost =
799       TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CostKind) +
800       (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
801                                      TargetTransformInfo::CastContextHint::None,
802                                      CostKind));
803   InstructionCost OldCost =
804       TTI.getShuffleCost(SK, SrcTy, Mask, CostKind) +
805       TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
806                            TargetTransformInfo::CastContextHint::None,
807                            CostKind);
808 
809   LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n  OldCost: "
810                     << OldCost << " vs NewCost: " << NewCost << "\n");
811 
812   if (NewCost > OldCost || !NewCost.isValid())
813     return false;
814 
815   // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
816   ++NumShufOfBitcast;
817   Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
818   Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
819   Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
820   replaceValue(I, *Shuf);
821   return true;
822 }
823 
824 /// VP Intrinsics whose vector operands are both splat values may be simplified
825 /// into the scalar version of the operation and the result splatted. This
826 /// can lead to scalarization down the line.
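/// A sketch of the rewrite, assuming an all-true mask and an EVL known to be
/// nonzero (or an op that is safe to speculate):
///   vp.add(splat(%x), splat(%y), <all-true>, %evl) --> splat(%x + %y)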
827 bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
828   if (!isa<VPIntrinsic>(I))
829     return false;
830   VPIntrinsic &VPI = cast<VPIntrinsic>(I);
831   Value *Op0 = VPI.getArgOperand(0);
832   Value *Op1 = VPI.getArgOperand(1);
833 
834   if (!isSplatValue(Op0) || !isSplatValue(Op1))
835     return false;
836 
837   // Check getSplatValue early in this function, to avoid doing unnecessary
838   // work.
839   Value *ScalarOp0 = getSplatValue(Op0);
840   Value *ScalarOp1 = getSplatValue(Op1);
841   if (!ScalarOp0 || !ScalarOp1)
842     return false;
843 
844   // For the binary VP intrinsics supported here, the result on disabled lanes
845   // is a poison value. For now, only do this simplification if all lanes
846   // are active.
847   // TODO: Relax the condition that all lanes are active by using insertelement
848   // on inactive lanes.
849   auto IsAllTrueMask = [](Value *MaskVal) {
850     if (Value *SplattedVal = getSplatValue(MaskVal))
851       if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
852         return ConstValue->isAllOnesValue();
853     return false;
854   };
855   if (!IsAllTrueMask(VPI.getArgOperand(2)))
856     return false;
857 
858   // Check to make sure we support scalarization of the intrinsic
859   Intrinsic::ID IntrID = VPI.getIntrinsicID();
860   if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
861     return false;
862 
863   // Calculate cost of splatting both operands into vectors and the vector
864   // intrinsic
865   VectorType *VecTy = cast<VectorType>(VPI.getType());
866   SmallVector<int> Mask;
867   if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
868     Mask.resize(FVTy->getNumElements(), 0);
869   InstructionCost SplatCost =
870       TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
871       TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, Mask,
872                          CostKind);
873 
874   // Calculate the cost of the VP Intrinsic
875   SmallVector<Type *, 4> Args;
876   for (Value *V : VPI.args())
877     Args.push_back(V->getType());
878   IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
879   InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
880   InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
881 
882   // Determine scalar opcode
883   std::optional<unsigned> FunctionalOpcode =
884       VPI.getFunctionalOpcode();
885   std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
886   if (!FunctionalOpcode) {
887     ScalarIntrID = VPI.getFunctionalIntrinsicID();
888     if (!ScalarIntrID)
889       return false;
890   }
891 
892   // Calculate cost of scalarizing
893   InstructionCost ScalarOpCost = 0;
894   if (ScalarIntrID) {
895     IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
896     ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
897   } else {
898     ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
899                                               VecTy->getScalarType(), CostKind);
900   }
901 
902   // The existing splats may be kept around if other instructions use them.
903   InstructionCost CostToKeepSplats =
904       (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
905   InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
906 
907   LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
908                     << "\n");
909   LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
910                     << ", Cost of scalarizing: " << NewCost << "\n");
911 
912   // We want to scalarize unless the vector variant actually has lower cost.
913   if (OldCost < NewCost || !NewCost.isValid())
914     return false;
915 
916   // Scalarize the intrinsic
917   ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
918   Value *EVL = VPI.getArgOperand(3);
919 
920   // If the VP op might introduce UB or poison, we can only scalarize it when
921   // we know EVL > 0: if the EVL is zero, the original VP op is a no-op (and
922   // thus not UB), so we must make sure the unconditionally executed scalar op
923   // does not introduce UB either.
924   bool SafeToSpeculate;
925   if (ScalarIntrID)
926     SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID)
927                           .hasFnAttr(Attribute::AttrKind::Speculatable);
928   else
929     SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
930         *FunctionalOpcode, &VPI, nullptr, &AC, &DT);
931   if (!SafeToSpeculate &&
932       !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI)))
933     return false;
934 
935   Value *ScalarVal =
936       ScalarIntrID
937           ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
938                                     {ScalarOp0, ScalarOp1})
939           : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
940                                 ScalarOp0, ScalarOp1);
941 
942   replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
943   return true;
944 }
945 
946 /// Match a vector binop or compare instruction with at least one inserted
947 /// scalar operand and convert to scalar binop/cmp followed by insertelement.
948 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
949   CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
950   Value *Ins0, *Ins1;
951   if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
952       !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
953     return false;
954 
955   // Do not convert the vector condition of a vector select into a scalar
956   // condition. That may cause problems for codegen because of differences in
957   // boolean formats and register-file transfers.
958   // TODO: Can we account for that in the cost model?
959   bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
960   if (IsCmp)
961     for (User *U : I.users())
962       if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
963         return false;
964 
965   // Match against one or both scalar values being inserted into constant
966   // vectors:
967   // vec_op VecC0, (inselt VecC1, V1, Index)
968   // vec_op (inselt VecC0, V0, Index), VecC1
969   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
970   // TODO: Deal with mismatched index constants and variable indexes?
971   Constant *VecC0 = nullptr, *VecC1 = nullptr;
972   Value *V0 = nullptr, *V1 = nullptr;
973   uint64_t Index0 = 0, Index1 = 0;
974   if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
975                                m_ConstantInt(Index0))) &&
976       !match(Ins0, m_Constant(VecC0)))
977     return false;
978   if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
979                                m_ConstantInt(Index1))) &&
980       !match(Ins1, m_Constant(VecC1)))
981     return false;
982 
983   bool IsConst0 = !V0;
984   bool IsConst1 = !V1;
985   if (IsConst0 && IsConst1)
986     return false;
987   if (!IsConst0 && !IsConst1 && Index0 != Index1)
988     return false;
989 
990   auto *VecTy0 = cast<VectorType>(Ins0->getType());
991   auto *VecTy1 = cast<VectorType>(Ins1->getType());
992   if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
993       VecTy1->getElementCount().getKnownMinValue() <= Index1)
994     return false;
995 
996   // Bail for single insertion if it is a load.
997   // TODO: Handle this once getVectorInstrCost can model load/store costs.
998   auto *I0 = dyn_cast_or_null<Instruction>(V0);
999   auto *I1 = dyn_cast_or_null<Instruction>(V1);
1000   if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
1001       (IsConst1 && I0 && I0->mayReadFromMemory()))
1002     return false;
1003 
1004   uint64_t Index = IsConst0 ? Index1 : Index0;
1005   Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
1006   Type *VecTy = I.getType();
1007   assert(VecTy->isVectorTy() &&
1008          (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
1009          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1010           ScalarTy->isPointerTy()) &&
1011          "Unexpected types for insert element into binop or cmp");
1012 
1013   unsigned Opcode = I.getOpcode();
1014   InstructionCost ScalarOpCost, VectorOpCost;
1015   if (IsCmp) {
1016     CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
1017     ScalarOpCost = TTI.getCmpSelInstrCost(
1018         Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
1019     VectorOpCost = TTI.getCmpSelInstrCost(
1020         Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1021   } else {
1022     ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
1023     VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
1024   }
1025 
1026   // Get cost estimate for the insert element. This cost will factor into
1027   // both sequences.
1028   InstructionCost InsertCost = TTI.getVectorInstrCost(
1029       Instruction::InsertElement, VecTy, CostKind, Index);
1030   InstructionCost OldCost =
1031       (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
1032   InstructionCost NewCost = ScalarOpCost + InsertCost +
1033                             (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
1034                             (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);
1035 
1036   // We want to scalarize unless the vector variant actually has lower cost.
1037   if (OldCost < NewCost || !NewCost.isValid())
1038     return false;
1039 
1040   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1041   // inselt NewVecC, (scalar_op V0, V1), Index
1042   if (IsCmp)
1043     ++NumScalarCmp;
1044   else
1045     ++NumScalarBO;
1046 
1047   // For constant cases, extract the scalar element; this should constant fold.
1048   if (IsConst0)
1049     V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
1050   if (IsConst1)
1051     V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
1052 
1053   Value *Scalar =
1054       IsCmp ? Builder.CreateCmp(Pred, V0, V1)
1055             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
1056 
1057   Scalar->setName(I.getName() + ".scalar");
1058 
1059   // All IR flags are safe to back-propagate. There is no potential for extra
1060   // poison to be created by the scalar instruction.
1061   if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
1062     ScalarInst->copyIRFlags(&I);
1063 
1064   // Fold the vector constants in the original vectors into a new base vector.
1065   Value *NewVecC =
1066       IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
1067             : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
1068   Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
1069   replaceValue(I, *Insert);
1070   return true;
1071 }
1072 
1073 /// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1074 /// a vector into vector operations followed by extract. Note: The SLP pass
1075 /// may miss this pattern because of implementation problems.
1076 bool VectorCombine::foldExtractedCmps(Instruction &I) {
1077   auto *BI = dyn_cast<BinaryOperator>(&I);
1078 
1079   // We are looking for a scalar binop of booleans.
1080   // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1081   if (!BI || !I.getType()->isIntegerTy(1))
1082     return false;
1083 
1084   // The compare predicates should match, and each compare should have a
1085   // constant operand.
1086   Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
1087   Instruction *I0, *I1;
1088   Constant *C0, *C1;
1089   CmpPredicate P0, P1;
1090   // FIXME: Use CmpPredicate::getMatching here.
1091   if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
1092       !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) ||
1093       P0 != static_cast<CmpInst::Predicate>(P1))
1094     return false;
1095 
1096   // The compare operands must be extracts of the same vector with constant
1097   // extract indexes.
1098   Value *X;
1099   uint64_t Index0, Index1;
1100   if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) ||
1101       !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))
1102     return false;
1103 
1104   auto *Ext0 = cast<ExtractElementInst>(I0);
1105   auto *Ext1 = cast<ExtractElementInst>(I1);
1106   ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, CostKind);
1107   if (!ConvertToShuf)
1108     return false;
1109   assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1110          "Unknown ExtractElementInst");
1111 
1112   // The original scalar pattern is:
1113   // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1114   CmpInst::Predicate Pred = P0;
1115   unsigned CmpOpcode =
1116       CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp;
1117   auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
1118   if (!VecTy)
1119     return false;
1120 
1121   InstructionCost Ext0Cost =
1122       TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1123   InstructionCost Ext1Cost =
1124       TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1125   InstructionCost CmpCost = TTI.getCmpSelInstrCost(
1126       CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred,
1127       CostKind);
1128 
1129   InstructionCost OldCost =
1130       Ext0Cost + Ext1Cost + CmpCost * 2 +
1131       TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);
1132 
1133   // The proposed vector pattern is:
1134   // vcmp = cmp Pred X, VecC
1135   // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1136   int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1137   int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1138   auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
1139   InstructionCost NewCost = TTI.getCmpSelInstrCost(
1140       CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred,
1141       CostKind);
1142   SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1143   ShufMask[CheapIndex] = ExpensiveIndex;
1144   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
1145                                 ShufMask, CostKind);
1146   NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
1147   NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
1148   NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1149   NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1150 
1151   // Aggressively form vector ops if the cost is equal because the transform
1152   // may enable further optimization.
1153   // Codegen can reverse this transform (scalarize) if it was not profitable.
1154   if (OldCost < NewCost || !NewCost.isValid())
1155     return false;
1156 
1157   // Create a vector constant from the 2 scalar constants.
1158   SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1159                                    PoisonValue::get(VecTy->getElementType()));
1160   CmpC[Index0] = C0;
1161   CmpC[Index1] = C1;
1162   Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
1163   Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
1164   Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1165   Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1166   Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
1167   Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
1168   replaceValue(I, *NewExt);
1169   ++NumVecCmpBO;
1170   return true;
1171 }
1172 
1173 // Check if the memory loc is modified between two instrs in the same BB.
1174 static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1175                                  BasicBlock::iterator End,
1176                                  const MemoryLocation &Loc, AAResults &AA) {
1177   unsigned NumScanned = 0;
1178   return std::any_of(Begin, End, [&](const Instruction &Instr) {
1179     return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1180            ++NumScanned > MaxInstrsToScan;
1181   });
1182 }
1183 
1184 namespace {
1185 /// Helper class to indicate whether a vector index can be safely scalarized
1186 /// and whether a freeze needs to be inserted.
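/// For example, an index of the form 'i & 3' into a 4-element vector is always
/// in bounds, but 'i' itself may be poison, so the access is SafeWithFreeze:
/// freezing 'i' before the 'and' makes scalarization safe.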
1187 class ScalarizationResult {
1188   enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1189 
1190   StatusTy Status;
1191   Value *ToFreeze;
1192 
1193   ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1194       : Status(Status), ToFreeze(ToFreeze) {}
1195 
1196 public:
1197   ScalarizationResult(const ScalarizationResult &Other) = default;
1198   ~ScalarizationResult() {
1199     assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1200   }
1201 
1202   static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1203   static ScalarizationResult safe() { return {StatusTy::Safe}; }
1204   static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1205     return {StatusTy::SafeWithFreeze, ToFreeze};
1206   }
1207 
1208   /// Returns true if the index can be scalarized without requiring a freeze.
1209   bool isSafe() const { return Status == StatusTy::Safe; }
1210   /// Returns true if the index cannot be scalarized.
1211   bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1212   /// Returns true if the index can be scalarized, but requires inserting a
1213   /// freeze.
1214   bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1215 
1216   /// Reset the state to Unsafe and clear ToFreeze if set.
1217   void discard() {
1218     ToFreeze = nullptr;
1219     Status = StatusTy::Unsafe;
1220   }
1221 
1222   /// Freeze ToFreeze and update the use in \p UserI to use the frozen value.
1223   void freeze(IRBuilder<> &Builder, Instruction &UserI) {
1224     assert(isSafeWithFreeze() &&
1225            "should only be used when freezing is required");
1226     assert(is_contained(ToFreeze->users(), &UserI) &&
1227            "UserI must be a user of ToFreeze");
1228     IRBuilder<>::InsertPointGuard Guard(Builder);
1229     Builder.SetInsertPoint(cast<Instruction>(&UserI));
1230     Value *Frozen =
1231         Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
    for (Use &U : make_early_inc_range(UserI.operands()))
1233       if (U.get() == ToFreeze)
1234         U.set(Frozen);
1235 
1236     ToFreeze = nullptr;
1237   }
1238 };
1239 } // namespace
1240 
1241 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1242 /// Idx. \p Idx must access a valid vector element.
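/// Illustrative examples for an access into a <4 x i32> vector (hypothetical
/// values, not taken from a specific test):
///   - constant index 2                          -> Safe
///   - constant index 7                          -> Unsafe (out of bounds)
///   - %idx = and i64 %i, 3   (%i may be poison) -> SafeWithFreeze(%i)
///   - %idx = urem i64 %i, 4  (%i may be poison) -> SafeWithFreeze(%i)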
1243 static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1244                                               Instruction *CtxI,
1245                                               AssumptionCache &AC,
1246                                               const DominatorTree &DT) {
1247   // We do checks for both fixed vector types and scalable vector types.
1248   // This is the number of elements of fixed vector types,
1249   // or the minimum number of elements of scalable vector types.
1250   uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1251 
1252   if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1253     if (C->getValue().ult(NumElements))
1254       return ScalarizationResult::safe();
1255     return ScalarizationResult::unsafe();
1256   }
1257 
1258   unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1259   APInt Zero(IntWidth, 0);
1260   APInt MaxElts(IntWidth, NumElements);
1261   ConstantRange ValidIndices(Zero, MaxElts);
1262   ConstantRange IdxRange(IntWidth, true);
1263 
1264   if (isGuaranteedNotToBePoison(Idx, &AC)) {
1265     if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
1266                                                    true, &AC, CtxI, &DT)))
1267       return ScalarizationResult::safe();
1268     return ScalarizationResult::unsafe();
1269   }
1270 
  // If the index may be poison, check if its range is restricted by an 'and'
  // or 'urem'; if so, it is enough to freeze the base value feeding that op.
1273   Value *IdxBase;
1274   ConstantInt *CI;
1275   if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1276     IdxRange = IdxRange.binaryAnd(CI->getValue());
1277   } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1278     IdxRange = IdxRange.urem(CI->getValue());
1279   }
1280 
1281   if (ValidIndices.contains(IdxRange))
1282     return ScalarizationResult::safeWithFreeze(IdxBase);
1283   return ScalarizationResult::unsafe();
1284 }
1285 
1286 /// The memory operation on a vector of \p ScalarType had alignment of
1287 /// \p VectorAlignment. Compute the maximal, but conservatively correct,
1288 /// alignment that will be valid for the memory operation on a single scalar
1289 /// element of the same type with index \p Idx.
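/// A worked example (assuming a <4 x i32> vector with VectorAlignment = 16 and
/// an i32 store size of 4 bytes):
///   - index 0        -> commonAlignment(16, 0) = Align(16)
///   - index 1        -> commonAlignment(16, 4) = Align(4)
///   - index 2        -> commonAlignment(16, 8) = Align(8)
///   - variable index -> commonAlignment(16, 4) = Align(4) (conservative case)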
1290 static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1291                                                 Type *ScalarType, Value *Idx,
1292                                                 const DataLayout &DL) {
1293   if (auto *C = dyn_cast<ConstantInt>(Idx))
1294     return commonAlignment(VectorAlignment,
1295                            C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1296   return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1297 }
1298 
// Combine patterns like:
//   %0 = load <4 x i32>, ptr %a
//   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
//   store <4 x i32> %1, ptr %a
// to:
//   %0 = getelementptr inbounds <4 x i32>, ptr %a, i32 0, i32 1
//   store i32 %b, ptr %0
1307 bool VectorCombine::foldSingleElementStore(Instruction &I) {
1308   auto *SI = cast<StoreInst>(&I);
1309   if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1310     return false;
1311 
1312   // TODO: Combine more complicated patterns (multiple insert) by referencing
1313   // TargetTransformInfo.
1314   Instruction *Source;
1315   Value *NewElement;
1316   Value *Idx;
1317   if (!match(SI->getValueOperand(),
1318              m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1319                          m_Value(Idx))))
1320     return false;
1321 
1322   if (auto *Load = dyn_cast<LoadInst>(Source)) {
1323     auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1324     Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1325     // Don't optimize for atomic/volatile load or store. Ensure memory is not
    // modified in between, vector type matches store size, and index is inbounds.
1327     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1328         !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1329         SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1330       return false;
1331 
1332     auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
1333     if (ScalarizableIdx.isUnsafe() ||
1334         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1335                              MemoryLocation::get(SI), AA))
1336       return false;
1337 
1338     if (ScalarizableIdx.isSafeWithFreeze())
1339       ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1340     Value *GEP = Builder.CreateInBoundsGEP(
1341         SI->getValueOperand()->getType(), SI->getPointerOperand(),
1342         {ConstantInt::get(Idx->getType(), 0), Idx});
1343     StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
1344     NSI->copyMetadata(*SI);
1345     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1346         std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
1347         *DL);
1348     NSI->setAlignment(ScalarOpAlignment);
1349     replaceValue(I, *NSI);
1350     eraseInstruction(I);
1351     return true;
1352   }
1353 
1354   return false;
1355 }
1356 
1357 /// Try to scalarize vector loads feeding extractelement instructions.
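/// A minimal sketch of the rewrite (illustrative IR, names are hypothetical):
///   %v = load <4 x i32>, ptr %p
///   %r = extractelement <4 x i32> %v, i64 %idx
/// -->
///   %gep = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 %idx
///   %r   = load i32, ptr %gep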
1358 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1359   Value *Ptr;
1360   if (!match(&I, m_Load(m_Value(Ptr))))
1361     return false;
1362 
1363   auto *VecTy = cast<VectorType>(I.getType());
1364   auto *LI = cast<LoadInst>(&I);
1365   if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
1366     return false;
1367 
1368   InstructionCost OriginalCost =
1369       TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
1370                           LI->getPointerAddressSpace(), CostKind);
1371   InstructionCost ScalarizedCost = 0;
1372 
1373   Instruction *LastCheckedInst = LI;
1374   unsigned NumInstChecked = 0;
1375   DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1376   auto FailureGuard = make_scope_exit([&]() {
1377     // If the transform is aborted, discard the ScalarizationResults.
1378     for (auto &Pair : NeedFreeze)
1379       Pair.second.discard();
1380   });
1381 
1382   // Check if all users of the load are extracts with no memory modifications
1383   // between the load and the extract. Compute the cost of both the original
1384   // code and the scalarized version.
1385   for (User *U : LI->users()) {
1386     auto *UI = dyn_cast<ExtractElementInst>(U);
1387     if (!UI || UI->getParent() != LI->getParent())
1388       return false;
1389 
1390     // Check if any instruction between the load and the extract may modify
1391     // memory.
1392     if (LastCheckedInst->comesBefore(UI)) {
1393       for (Instruction &I :
1394            make_range(std::next(LI->getIterator()), UI->getIterator())) {
1395         // Bail out if we reached the check limit or the instruction may write
1396         // to memory.
1397         if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
1398           return false;
1399         NumInstChecked++;
1400       }
1401       LastCheckedInst = UI;
1402     }
1403 
1404     auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
1405     if (ScalarIdx.isUnsafe())
1406       return false;
1407     if (ScalarIdx.isSafeWithFreeze()) {
1408       NeedFreeze.try_emplace(UI, ScalarIdx);
1409       ScalarIdx.discard();
1410     }
1411 
1412     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
1413     OriginalCost +=
1414         TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
1415                                Index ? Index->getZExtValue() : -1);
1416     ScalarizedCost +=
1417         TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
1418                             Align(1), LI->getPointerAddressSpace(), CostKind);
1419     ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
1420   }
1421 
1422   if (ScalarizedCost >= OriginalCost)
1423     return false;
1424 
1425   // Replace extracts with narrow scalar loads.
1426   for (User *U : LI->users()) {
1427     auto *EI = cast<ExtractElementInst>(U);
1428     Value *Idx = EI->getOperand(1);
1429 
1430     // Insert 'freeze' for poison indexes.
1431     auto It = NeedFreeze.find(EI);
1432     if (It != NeedFreeze.end())
1433       It->second.freeze(Builder, *cast<Instruction>(Idx));
1434 
1435     Builder.SetInsertPoint(EI);
1436     Value *GEP =
1437         Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
1438     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
1439         VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
1440 
1441     Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1442         LI->getAlign(), VecTy->getElementType(), Idx, *DL);
1443     NewLoad->setAlignment(ScalarOpAlignment);
1444 
1445     replaceValue(*EI, *NewLoad);
1446   }
1447 
1448   FailureGuard.release();
1449   return true;
1450 }
1451 
1452 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
1453 /// to "(bitcast (concat X, Y))"
1454 /// where X/Y are bitcasted from i1 mask vectors.
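/// A small illustrative example (little-endian, hypothetical types/values):
///   %xb = bitcast <4 x i1> %x to i4
///   %yb = bitcast <4 x i1> %y to i4
///   %xz = zext i4 %xb to i8
///   %yz = zext i4 %yb to i8
///   %ys = shl i8 %yz, 4
///   %r  = or disjoint i8 %xz, %ys
/// -->
///   %concat = shufflevector <4 x i1> %x, <4 x i1> %y,
///                           <8 x i32> <i32 0, i32 1, i32 2, i32 3,
///                                      i32 4, i32 5, i32 6, i32 7>
///   %r      = bitcast <8 x i1> %concat to i8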
1455 bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
1456   Type *Ty = I.getType();
1457   if (!Ty->isIntegerTy())
1458     return false;
1459 
1460   // TODO: Add big endian test coverage
1461   if (DL->isBigEndian())
1462     return false;
1463 
1464   // Restrict to disjoint cases so the mask vectors aren't overlapping.
1465   Instruction *X, *Y;
1466   if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
1467     return false;
1468 
1469   // Allow both sources to contain shl, to handle more generic pattern:
1470   // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
1471   Value *SrcX;
1472   uint64_t ShAmtX = 0;
1473   if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
1474       !match(X, m_OneUse(
1475                     m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
1476                           m_ConstantInt(ShAmtX)))))
1477     return false;
1478 
1479   Value *SrcY;
1480   uint64_t ShAmtY = 0;
1481   if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
1482       !match(Y, m_OneUse(
1483                     m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
1484                           m_ConstantInt(ShAmtY)))))
1485     return false;
1486 
1487   // Canonicalize larger shift to the RHS.
1488   if (ShAmtX > ShAmtY) {
1489     std::swap(X, Y);
1490     std::swap(SrcX, SrcY);
1491     std::swap(ShAmtX, ShAmtY);
1492   }
1493 
1494   // Ensure both sources are matching vXi1 bool mask types, and that the shift
1495   // difference is the mask width so they can be easily concatenated together.
1496   uint64_t ShAmtDiff = ShAmtY - ShAmtX;
1497   unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
1498   unsigned BitWidth = Ty->getPrimitiveSizeInBits();
1499   auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
1500   if (!MaskTy || SrcX->getType() != SrcY->getType() ||
1501       !MaskTy->getElementType()->isIntegerTy(1) ||
1502       MaskTy->getNumElements() != ShAmtDiff ||
1503       MaskTy->getNumElements() > (BitWidth / 2))
1504     return false;
1505 
1506   auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
1507   auto *ConcatIntTy =
1508       Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
1509   auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
1510 
1511   SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
1512   std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
1513 
1514   // TODO: Is it worth supporting multi use cases?
1515   InstructionCost OldCost = 0;
1516   OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
1517   OldCost +=
1518       NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1519   OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
1520                                       TTI::CastContextHint::None, CostKind);
1521   OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
1522                                       TTI::CastContextHint::None, CostKind);
1523 
1524   InstructionCost NewCost = 0;
1525   NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy,
1526                                 ConcatMask, CostKind);
1527   NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
1528                                   TTI::CastContextHint::None, CostKind);
1529   if (Ty != ConcatIntTy)
1530     NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
1531                                     TTI::CastContextHint::None, CostKind);
1532   if (ShAmtX > 0)
1533     NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1534 
1535   LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
1536                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1537                     << "\n");
1538 
1539   if (NewCost > OldCost)
1540     return false;
1541 
1542   // Build bool mask concatenation, bitcast back to scalar integer, and perform
1543   // any residual zero-extension or shifting.
1544   Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
1545   Worklist.pushValue(Concat);
1546 
1547   Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
1548 
1549   if (Ty != ConcatIntTy) {
1550     Worklist.pushValue(Result);
1551     Result = Builder.CreateZExt(Result, Ty);
1552   }
1553 
1554   if (ShAmtX > 0) {
1555     Worklist.pushValue(Result);
1556     Result = Builder.CreateShl(Result, ShAmtX);
1557   }
1558 
1559   replaceValue(I, *Result);
1560   return true;
1561 }
1562 
1563 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1564 ///           -->  "binop (shuffle), (shuffle)".
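/// A small illustrative example (hypothetical values, not from a specific
/// test):
///   %s0 = shufflevector <4 x i32> %a, <4 x i32> %b,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %s1 = shufflevector <4 x i32> %c, <4 x i32> %d,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %bo = add <4 x i32> %s0, %s1
///   %r  = shufflevector <4 x i32> %bo, <4 x i32> poison,
///                       <4 x i32> <i32 3, i32 2, i32 1, i32 0>
/// becomes, with the outer and inner masks merged:
///   %t0 = shufflevector <4 x i32> %a, <4 x i32> %b,
///                       <4 x i32> <i32 5, i32 1, i32 4, i32 0>
///   %t1 = shufflevector <4 x i32> %c, <4 x i32> %d,
///                       <4 x i32> <i32 5, i32 1, i32 4, i32 0>
///   %r  = add <4 x i32> %t0, %t1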
1565 bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1566   BinaryOperator *BinOp;
1567   ArrayRef<int> OuterMask;
1568   if (!match(&I,
1569              m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask))))
1570     return false;
1571 
1572   // Don't introduce poison into div/rem.
1573   if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem))
1574     return false;
1575 
1576   Value *Op00, *Op01;
1577   ArrayRef<int> Mask0;
1578   if (!match(BinOp->getOperand(0),
1579              m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)))))
1580     return false;
1581 
1582   Value *Op10, *Op11;
1583   ArrayRef<int> Mask1;
1584   if (!match(BinOp->getOperand(1),
1585              m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)))))
1586     return false;
1587 
1588   Instruction::BinaryOps Opcode = BinOp->getOpcode();
1589   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1590   auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
1591   auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
1592   auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
1593   if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1594     return false;
1595 
1596   unsigned NumSrcElts = BinOpTy->getNumElements();
1597 
1598   // Don't accept shuffles that reference the second operand in
  // div/rem or if it's an undef arg.
1600   if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) &&
1601       any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
1602     return false;
1603 
1604   // Merge outer / inner shuffles.
1605   SmallVector<int> NewMask0, NewMask1;
1606   for (int M : OuterMask) {
1607     if (M < 0 || M >= (int)NumSrcElts) {
1608       NewMask0.push_back(PoisonMaskElem);
1609       NewMask1.push_back(PoisonMaskElem);
1610     } else {
1611       NewMask0.push_back(Mask0[M]);
1612       NewMask1.push_back(Mask1[M]);
1613     }
1614   }
1615 
1616   // Try to merge shuffles across the binop if the new shuffles are not costly.
1617   InstructionCost OldCost =
1618       TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) +
1619       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy,
1620                          OuterMask, CostKind, 0, nullptr, {BinOp}, &I) +
1621       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0,
1622                          CostKind, 0, nullptr, {Op00, Op01},
1623                          cast<Instruction>(BinOp->getOperand(0))) +
1624       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1,
1625                          CostKind, 0, nullptr, {Op10, Op11},
1626                          cast<Instruction>(BinOp->getOperand(1)));
1627 
1628   InstructionCost NewCost =
1629       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0,
1630                          CostKind, 0, nullptr, {Op00, Op01}) +
1631       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1,
1632                          CostKind, 0, nullptr, {Op10, Op11}) +
1633       TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
1634 
1635   LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
1636                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1637                     << "\n");
1638 
1639   // If costs are equal, still fold as we reduce instruction count.
1640   if (NewCost > OldCost)
1641     return false;
1642 
1643   Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0);
1644   Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1);
1645   Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
1646 
  // Copy flags from the old binop.
1648   if (auto *NewInst = dyn_cast<Instruction>(NewBO))
1649     NewInst->copyIRFlags(BinOp);
1650 
1651   Worklist.pushValue(Shuf0);
1652   Worklist.pushValue(Shuf1);
1653   replaceValue(I, *NewBO);
1654   return true;
1655 }
1656 
1657 /// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
1658 /// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
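/// A small illustrative example (hypothetical values):
///   %a = add <4 x i32> %x, %y
///   %b = add <4 x i32> %z, %w
///   %r = shufflevector <4 x i32> %a, <4 x i32> %b,
///                      <4 x i32> <i32 0, i32 4, i32 1, i32 5>
/// -->
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> %z,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> %w,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %r  = add <4 x i32> %s0, %s1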
1659 bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
1660   ArrayRef<int> OldMask;
1661   Instruction *LHS, *RHS;
1662   if (!match(&I, m_Shuffle(m_OneUse(m_Instruction(LHS)),
1663                            m_OneUse(m_Instruction(RHS)), m_Mask(OldMask))))
1664     return false;
1665 
1666   // TODO: Add support for addlike etc.
1667   if (LHS->getOpcode() != RHS->getOpcode())
1668     return false;
1669 
1670   Value *X, *Y, *Z, *W;
1671   bool IsCommutative = false;
1672   CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
1673   CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
1674   if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
1675       match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
1676     auto *BO = cast<BinaryOperator>(LHS);
1677     // Don't introduce poison into div/rem.
1678     if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
1679       return false;
1680     IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
1681   } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) &&
1682              match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) &&
1683              (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
1684     IsCommutative = cast<CmpInst>(LHS)->isCommutative();
1685   } else
1686     return false;
1687 
1688   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1689   auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType());
1690   auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType());
1691   if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
1692     return false;
1693 
1694   unsigned NumSrcElts = BinOpTy->getNumElements();
1695 
1696   // If we have something like "add X, Y" and "add Z, X", swap ops to match.
1697   if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
1698     std::swap(X, Y);
1699 
1700   auto ConvertToUnary = [NumSrcElts](int &M) {
1701     if (M >= (int)NumSrcElts)
1702       M -= NumSrcElts;
1703   };
1704 
1705   SmallVector<int> NewMask0(OldMask);
1706   TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
1707   if (X == Z) {
1708     llvm::for_each(NewMask0, ConvertToUnary);
1709     SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
1710     Z = PoisonValue::get(BinOpTy);
1711   }
1712 
1713   SmallVector<int> NewMask1(OldMask);
1714   TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
1715   if (Y == W) {
1716     llvm::for_each(NewMask1, ConvertToUnary);
1717     SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
1718     W = PoisonValue::get(BinOpTy);
1719   }
1720 
1721   // Try to replace a binop with a shuffle if the shuffle is not costly.
1722   InstructionCost OldCost =
1723       TTI.getInstructionCost(LHS, CostKind) +
1724       TTI.getInstructionCost(RHS, CostKind) +
1725       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinResTy,
1726                          OldMask, CostKind, 0, nullptr, {LHS, RHS}, &I);
1727 
1728   InstructionCost NewCost =
1729       TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) +
1730       TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W});
1731 
1732   if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
1733     NewCost +=
1734         TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind);
1735   } else {
1736     auto *ShuffleCmpTy =
1737         FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
1738     NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy,
1739                                       ShuffleDstTy, PredLHS, CostKind);
1740   }
1741 
1742   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
1743                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1744                     << "\n");
1745 
1746   // If either shuffle will constant fold away, then fold for the same cost as
1747   // we will reduce the instruction count.
1748   bool ReducedInstCount = (isa<Constant>(X) && isa<Constant>(Z)) ||
1749                           (isa<Constant>(Y) && isa<Constant>(W));
1750   if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
1751     return false;
1752 
1753   Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
1754   Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1);
1755   Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
1756                      ? Builder.CreateBinOp(
1757                            cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
1758                      : Builder.CreateCmp(PredLHS, Shuf0, Shuf1);
1759 
1760   // Intersect flags from the old binops.
1761   if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
1762     NewInst->copyIRFlags(LHS);
1763     NewInst->andIRFlags(RHS);
1764   }
1765 
1766   Worklist.pushValue(Shuf0);
1767   Worklist.pushValue(Shuf1);
1768   replaceValue(I, *NewBO);
1769   return true;
1770 }
1771 
/// Try to convert "shuffle (castop), (castop)" with matching castop source
/// types into "castop (shuffle)".
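/// A small illustrative example (hypothetical values):
///   %c0 = zext <4 x i16> %x to <4 x i32>
///   %c1 = zext <4 x i16> %y to <4 x i32>
///   %r  = shufflevector <4 x i32> %c0, <4 x i32> %c1,
///                       <8 x i32> <i32 0, i32 4, i32 1, i32 5,
///                                  i32 2, i32 6, i32 3, i32 7>
/// -->
///   %s = shufflevector <4 x i16> %x, <4 x i16> %y,
///                      <8 x i32> <i32 0, i32 4, i32 1, i32 5,
///                                 i32 2, i32 6, i32 3, i32 7>
///   %r = zext <8 x i16> %s to <8 x i32>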
1774 bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
1775   Value *V0, *V1;
1776   ArrayRef<int> OldMask;
1777   if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
1778     return false;
1779 
1780   auto *C0 = dyn_cast<CastInst>(V0);
1781   auto *C1 = dyn_cast<CastInst>(V1);
1782   if (!C0 || !C1)
1783     return false;
1784 
1785   Instruction::CastOps Opcode = C0->getOpcode();
1786   if (C0->getSrcTy() != C1->getSrcTy())
1787     return false;
1788 
1789   // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
1790   if (Opcode != C1->getOpcode()) {
1791     if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
1792       Opcode = Instruction::SExt;
1793     else
1794       return false;
1795   }
1796 
1797   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1798   auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
1799   auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
1800   if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
1801     return false;
1802 
1803   unsigned NumSrcElts = CastSrcTy->getNumElements();
1804   unsigned NumDstElts = CastDstTy->getNumElements();
1805   assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
1806          "Only bitcasts expected to alter src/dst element counts");
1807 
  // Bail on bitcasts between types whose element counts are not exact
  // multiples of each other, e.g. <32 x i40> -> <40 x i32>.
1810   if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
1811       (NumDstElts % NumSrcElts) != 0)
1812     return false;
1813 
1814   SmallVector<int, 16> NewMask;
1815   if (NumSrcElts >= NumDstElts) {
    // The mask currently selects the wider (or equal) cast-result elements; it
    // can always be expanded to the equivalent narrower-element source form.
1818     assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
1819     unsigned ScaleFactor = NumSrcElts / NumDstElts;
1820     narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
1821   } else {
    // The mask currently selects narrower cast-result elements; it must choose
    // consecutive elements so it can be widened to the wider source elements.
1824     assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
1825     unsigned ScaleFactor = NumDstElts / NumSrcElts;
1826     if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
1827       return false;
1828   }
1829 
1830   auto *NewShuffleDstTy =
1831       FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());
1832 
1833   // Try to replace a castop with a shuffle if the shuffle is not costly.
1834   InstructionCost CostC0 =
1835       TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
1836                            TTI::CastContextHint::None, CostKind);
1837   InstructionCost CostC1 =
1838       TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
1839                            TTI::CastContextHint::None, CostKind);
1840   InstructionCost OldCost = CostC0 + CostC1;
1841   OldCost +=
1842       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy,
1843                          OldMask, CostKind, 0, nullptr, {}, &I);
1844 
1845   InstructionCost NewCost = TTI.getShuffleCost(
1846       TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind);
1847   NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
1848                                   TTI::CastContextHint::None, CostKind);
1849   if (!C0->hasOneUse())
1850     NewCost += CostC0;
1851   if (!C1->hasOneUse())
1852     NewCost += CostC1;
1853 
1854   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
1855                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1856                     << "\n");
1857   if (NewCost > OldCost)
1858     return false;
1859 
1860   Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
1861                                             C1->getOperand(0), NewMask);
1862   Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
1863 
1864   // Intersect flags from the old casts.
1865   if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
1866     NewInst->copyIRFlags(C0);
1867     NewInst->andIRFlags(C1);
1868   }
1869 
1870   Worklist.pushValue(Shuf);
1871   replaceValue(I, *Cast);
1872   return true;
1873 }
1874 
1875 /// Try to convert any of:
1876 /// "shuffle (shuffle x, y), (shuffle y, x)"
1877 /// "shuffle (shuffle x, undef), (shuffle y, undef)"
1878 /// "shuffle (shuffle x, undef), y"
1879 /// "shuffle x, (shuffle y, undef)"
1880 /// into "shuffle x, y".
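/// A small illustrative example (hypothetical values):
///   %i0 = shufflevector <4 x i32> %x, <4 x i32> poison,
///                       <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %i1 = shufflevector <4 x i32> %y, <4 x i32> poison,
///                       <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %r  = shufflevector <4 x i32> %i0, <4 x i32> %i1,
///                       <4 x i32> <i32 0, i32 1, i32 4, i32 5>
/// -->
///   %r = shufflevector <4 x i32> %x, <4 x i32> %y,
///                      <4 x i32> <i32 3, i32 2, i32 7, i32 6>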
1881 bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
1882   ArrayRef<int> OuterMask;
1883   Value *OuterV0, *OuterV1;
1884   if (!match(&I,
1885              m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask))))
1886     return false;
1887 
1888   ArrayRef<int> InnerMask0, InnerMask1;
1889   Value *X0, *X1, *Y0, *Y1;
1890   bool Match0 =
1891       match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0)));
1892   bool Match1 =
1893       match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1)));
1894   if (!Match0 && !Match1)
1895     return false;
1896 
1897   X0 = Match0 ? X0 : OuterV0;
1898   Y0 = Match0 ? Y0 : OuterV0;
1899   X1 = Match1 ? X1 : OuterV1;
1900   Y1 = Match1 ? Y1 : OuterV1;
1901   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1902   auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType());
1903   auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType());
1904   if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
1905       X0->getType() != X1->getType())
1906     return false;
1907 
1908   unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
1909   unsigned NumImmElts = ShuffleImmTy->getNumElements();
1910 
  // Attempt to merge the shuffles, matching up to 2 source operands.
  // Replace any index referencing a poison arg with PoisonMaskElem.
  // Bail if either inner mask references a (non-poison) undef arg.
1914   SmallVector<int, 16> NewMask(OuterMask);
1915   Value *NewX = nullptr, *NewY = nullptr;
1916   for (int &M : NewMask) {
1917     Value *Src = nullptr;
1918     if (0 <= M && M < (int)NumImmElts) {
1919       Src = OuterV0;
1920       if (Match0) {
1921         M = InnerMask0[M];
1922         Src = M >= (int)NumSrcElts ? Y0 : X0;
1923         M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
1924       }
1925     } else if (M >= (int)NumImmElts) {
1926       Src = OuterV1;
1927       M -= NumImmElts;
1928       if (Match1) {
1929         M = InnerMask1[M];
1930         Src = M >= (int)NumSrcElts ? Y1 : X1;
1931         M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
1932       }
1933     }
1934     if (Src && M != PoisonMaskElem) {
1935       assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
1936       if (isa<UndefValue>(Src)) {
        // We've referenced an undef element - if it's poison, update the shuffle
1938         // mask, else bail.
1939         if (!isa<PoisonValue>(Src))
1940           return false;
1941         M = PoisonMaskElem;
1942         continue;
1943       }
1944       if (!NewX || NewX == Src) {
1945         NewX = Src;
1946         continue;
1947       }
1948       if (!NewY || NewY == Src) {
1949         M += NumSrcElts;
1950         NewY = Src;
1951         continue;
1952       }
1953       return false;
1954     }
1955   }
1956 
  // Every referenced lane came from a poison operand - fold to poison.
  if (!NewX) {
    replaceValue(I, *PoisonValue::get(ShuffleDstTy));
    return true;
  }
1959   if (!NewY)
1960     NewY = PoisonValue::get(ShuffleSrcTy);
1961 
1962   // Have we folded to an Identity shuffle?
1963   if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
1964     replaceValue(I, *NewX);
1965     return true;
1966   }
1967 
1968   // Try to merge the shuffles if the new shuffle is not costly.
1969   InstructionCost InnerCost0 = 0;
1970   if (Match0)
1971     InnerCost0 = TTI.getInstructionCost(cast<Instruction>(OuterV0), CostKind);
1972 
1973   InstructionCost InnerCost1 = 0;
1974   if (Match1)
1975     InnerCost1 = TTI.getInstructionCost(cast<Instruction>(OuterV1), CostKind);
1976 
1977   InstructionCost OuterCost = TTI.getShuffleCost(
1978       TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind,
1979       0, nullptr, {OuterV0, OuterV1}, &I);
1980 
1981   InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
1982 
1983   bool IsUnary = all_of(NewMask, [&](int M) { return M < (int)NumSrcElts; });
1984   TargetTransformInfo::ShuffleKind SK =
1985       IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
1986               : TargetTransformInfo::SK_PermuteTwoSrc;
1987   InstructionCost NewCost = TTI.getShuffleCost(
1988       SK, ShuffleSrcTy, NewMask, CostKind, 0, nullptr, {NewX, NewY});
1989   if (!OuterV0->hasOneUse())
1990     NewCost += InnerCost0;
1991   if (!OuterV1->hasOneUse())
1992     NewCost += InnerCost1;
1993 
1994   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
1995                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
1996                     << "\n");
1997   if (NewCost > OldCost)
1998     return false;
1999 
2000   Value *Shuf = Builder.CreateShuffleVector(NewX, NewY, NewMask);
2001   replaceValue(I, *Shuf);
2002   return true;
2003 }
2004 
2005 /// Try to convert
2006 /// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
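/// A small illustrative example (hypothetical values):
///   %v0 = call <2 x i32> @llvm.smax.v2i32(<2 x i32> %a, <2 x i32> %b)
///   %v1 = call <2 x i32> @llvm.smax.v2i32(<2 x i32> %c, <2 x i32> %d)
///   %r  = shufflevector <2 x i32> %v0, <2 x i32> %v1,
///                       <4 x i32> <i32 0, i32 1, i32 2, i32 3>
/// -->
///   %s0 = shufflevector <2 x i32> %a, <2 x i32> %c,
///                       <4 x i32> <i32 0, i32 1, i32 2, i32 3>
///   %s1 = shufflevector <2 x i32> %b, <2 x i32> %d,
///                       <4 x i32> <i32 0, i32 1, i32 2, i32 3>
///   %r  = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %s0, <4 x i32> %s1)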
2007 bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
2008   Value *V0, *V1;
2009   ArrayRef<int> OldMask;
2010   if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
2011                            m_Mask(OldMask))))
2012     return false;
2013 
2014   auto *II0 = dyn_cast<IntrinsicInst>(V0);
2015   auto *II1 = dyn_cast<IntrinsicInst>(V1);
2016   if (!II0 || !II1)
2017     return false;
2018 
2019   Intrinsic::ID IID = II0->getIntrinsicID();
2020   if (IID != II1->getIntrinsicID())
2021     return false;
2022 
2023   auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2024   auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType());
2025   if (!ShuffleDstTy || !II0Ty)
2026     return false;
2027 
2028   if (!isTriviallyVectorizable(IID))
2029     return false;
2030 
2031   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2032     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI) &&
2033         II0->getArgOperand(I) != II1->getArgOperand(I))
2034       return false;
2035 
2036   InstructionCost OldCost =
2037       TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) +
2038       TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind) +
2039       TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, II0Ty, OldMask,
2040                          CostKind, 0, nullptr, {II0, II1}, &I);
2041 
2042   SmallVector<Type *> NewArgsTy;
2043   InstructionCost NewCost = 0;
2044   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2045     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
2046       NewArgsTy.push_back(II0->getArgOperand(I)->getType());
2047     } else {
2048       auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
2049       NewArgsTy.push_back(FixedVectorType::get(VecTy->getElementType(),
2050                                                VecTy->getNumElements() * 2));
2051       NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
2052                                     VecTy, OldMask, CostKind);
2053     }
2054   IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
2055   NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
2056 
2057   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
2058                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
2059                     << "\n");
2060 
2061   if (NewCost > OldCost)
2062     return false;
2063 
2064   SmallVector<Value *> NewArgs;
2065   for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2066     if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
2067       NewArgs.push_back(II0->getArgOperand(I));
2068     } else {
2069       Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I),
2070                                                 II1->getArgOperand(I), OldMask);
2071       NewArgs.push_back(Shuf);
2072       Worklist.pushValue(Shuf);
2073     }
2074   Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
2075 
2076   // Intersect flags from the old intrinsics.
2077   if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) {
2078     NewInst->copyIRFlags(II0);
2079     NewInst->andIRFlags(II1);
2080   }
2081 
2082   replaceValue(I, *NewIntrinsic);
2083   return true;
2084 }
2085 
2086 using InstLane = std::pair<Use *, int>;
2087 
2088 static InstLane lookThroughShuffles(Use *U, int Lane) {
2089   while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) {
2090     unsigned NumElts =
2091         cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
2092     int M = SV->getMaskValue(Lane);
2093     if (M < 0)
2094       return {nullptr, PoisonMaskElem};
2095     if (static_cast<unsigned>(M) < NumElts) {
2096       U = &SV->getOperandUse(0);
2097       Lane = M;
2098     } else {
2099       U = &SV->getOperandUse(1);
2100       Lane = M - NumElts;
2101     }
2102   }
2103   return InstLane{U, Lane};
2104 }
2105 
2106 static SmallVector<InstLane>
2107 generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
2108   SmallVector<InstLane> NItem;
2109   for (InstLane IL : Item) {
2110     auto [U, Lane] = IL;
2111     InstLane OpLane =
2112         U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op),
2113                                 Lane)
2114           : InstLane{nullptr, PoisonMaskElem};
2115     NItem.emplace_back(OpLane);
2116   }
2117   return NItem;
2118 }
2119 
/// Detect whether the lanes in \p Item form the concatenation of several
/// source vectors that the target can treat as free.
2121 static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
2122                          const TargetTransformInfo &TTI) {
2123   auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType());
2124   unsigned NumElts = Ty->getNumElements();
2125   if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
2126     return false;
2127 
2128   // Check that the concat is free, usually meaning that the type will be split
2129   // during legalization.
2130   SmallVector<int, 16> ConcatMask(NumElts * 2);
2131   std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
2132   if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, CostKind) != 0)
2133     return false;
2134 
2135   unsigned NumSlices = Item.size() / NumElts;
2136   // Currently we generate a tree of shuffles for the concats, which limits us
  // to a power of 2.
2138   if (!isPowerOf2_32(NumSlices))
2139     return false;
2140   for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
2141     Use *SliceV = Item[Slice * NumElts].first;
2142     if (!SliceV || SliceV->get()->getType() != Ty)
2143       return false;
2144     for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
2145       auto [V, Lane] = Item[Slice * NumElts + Elt];
2146       if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
2147         return false;
2148     }
2149   }
2150   return true;
2151 }
2152 
2153 static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
2154                                   const SmallPtrSet<Use *, 4> &IdentityLeafs,
2155                                   const SmallPtrSet<Use *, 4> &SplatLeafs,
2156                                   const SmallPtrSet<Use *, 4> &ConcatLeafs,
2157                                   IRBuilder<> &Builder,
2158                                   const TargetTransformInfo *TTI) {
2159   auto [FrontU, FrontLane] = Item.front();
2160 
2161   if (IdentityLeafs.contains(FrontU)) {
2162     return FrontU->get();
2163   }
2164   if (SplatLeafs.contains(FrontU)) {
2165     SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
2166     return Builder.CreateShuffleVector(FrontU->get(), Mask);
2167   }
2168   if (ConcatLeafs.contains(FrontU)) {
2169     unsigned NumElts =
2170         cast<FixedVectorType>(FrontU->get()->getType())->getNumElements();
2171     SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
2172     for (unsigned S = 0; S < Values.size(); ++S)
2173       Values[S] = Item[S * NumElts].first->get();
2174 
2175     while (Values.size() > 1) {
2176       NumElts *= 2;
2177       SmallVector<int, 16> Mask(NumElts, 0);
2178       std::iota(Mask.begin(), Mask.end(), 0);
2179       SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
2180       for (unsigned S = 0; S < NewValues.size(); ++S)
2181         NewValues[S] =
2182             Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
2183       Values = NewValues;
2184     }
2185     return Values[0];
2186   }
2187 
2188   auto *I = cast<Instruction>(FrontU->get());
2189   auto *II = dyn_cast<IntrinsicInst>(I);
2190   unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
2191   SmallVector<Value *> Ops(NumOps);
2192   for (unsigned Idx = 0; Idx < NumOps; Idx++) {
2193     if (II &&
2194         isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) {
2195       Ops[Idx] = II->getOperand(Idx);
2196       continue;
2197     }
2198     Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
2199                                    Ty, IdentityLeafs, SplatLeafs, ConcatLeafs,
2200                                    Builder, TTI);
2201   }
2202 
2203   SmallVector<Value *, 8> ValueList;
2204   for (const auto &Lane : Item)
2205     if (Lane.first)
2206       ValueList.push_back(Lane.first->get());
2207 
2208   Type *DstTy =
2209       FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
2210   if (auto *BI = dyn_cast<BinaryOperator>(I)) {
2211     auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
2212                                       Ops[0], Ops[1]);
2213     propagateIRFlags(Value, ValueList);
2214     return Value;
2215   }
2216   if (auto *CI = dyn_cast<CmpInst>(I)) {
2217     auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
2218     propagateIRFlags(Value, ValueList);
2219     return Value;
2220   }
2221   if (auto *SI = dyn_cast<SelectInst>(I)) {
2222     auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
2223     propagateIRFlags(Value, ValueList);
2224     return Value;
2225   }
2226   if (auto *CI = dyn_cast<CastInst>(I)) {
2227     auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(),
2228                                      Ops[0], DstTy);
2229     propagateIRFlags(Value, ValueList);
2230     return Value;
2231   }
2232   if (II) {
2233     auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
2234     propagateIRFlags(Value, ValueList);
2235     return Value;
2236   }
2237   assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
2238   auto *Value =
2239       Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
2240   propagateIRFlags(Value, ValueList);
2241   return Value;
2242 }
2243 
2244 // Starting from a shuffle, look up through operands tracking the shuffled index
2245 // of each lane. If we can simplify away the shuffles to identities then
2246 // do so.
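// A minimal illustrative example (hypothetical values): both lanes of %r trace
// back to the lanes of %x in their original order, so the shuffles can be
// removed:
//   %a = shufflevector <2 x float> %x, <2 x float> poison,
//                      <2 x i32> <i32 1, i32 0>
//   %b = fneg <2 x float> %a
//   %r = shufflevector <2 x float> %b, <2 x float> poison,
//                      <2 x i32> <i32 1, i32 0>
// -->
//   %r = fneg <2 x float> %x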
2247 bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
2248   auto *Ty = dyn_cast<FixedVectorType>(I.getType());
2249   if (!Ty || I.use_empty())
2250     return false;
2251 
2252   SmallVector<InstLane> Start(Ty->getNumElements());
2253   for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
2254     Start[M] = lookThroughShuffles(&*I.use_begin(), M);
2255 
2256   SmallVector<SmallVector<InstLane>> Worklist;
2257   Worklist.push_back(Start);
2258   SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
2259   unsigned NumVisited = 0;
2260 
2261   while (!Worklist.empty()) {
2262     if (++NumVisited > MaxInstrsToScan)
2263       return false;
2264 
2265     SmallVector<InstLane> Item = Worklist.pop_back_val();
2266     auto [FrontU, FrontLane] = Item.front();
2267 
2268     // If we found an undef first lane then bail out to keep things simple.
2269     if (!FrontU)
2270       return false;
2271 
2272     // Helper to peek through bitcasts to the same value.
2273     auto IsEquiv = [&](Value *X, Value *Y) {
2274       return X->getType() == Y->getType() &&
2275              peekThroughBitcasts(X) == peekThroughBitcasts(Y);
2276     };
2277 
2278     // Look for an identity value.
2279     if (FrontLane == 0 &&
2280         cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() ==
2281             Ty->getNumElements() &&
2282         all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) {
2283           Value *FrontV = Item.front().first->get();
2284           return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) &&
2285                                       E.value().second == (int)E.index());
2286         })) {
2287       IdentityLeafs.insert(FrontU);
2288       continue;
2289     }
2290     // Look for constants, for the moment only supporting constant splats.
2291     if (auto *C = dyn_cast<Constant>(FrontU);
2292         C && C->getSplatValue() &&
2293         all_of(drop_begin(Item), [Item](InstLane &IL) {
2294           Value *FrontV = Item.front().first->get();
2295           Use *U = IL.first;
2296           return !U || (isa<Constant>(U->get()) &&
2297                         cast<Constant>(U->get())->getSplatValue() ==
2298                             cast<Constant>(FrontV)->getSplatValue());
2299         })) {
2300       SplatLeafs.insert(FrontU);
2301       continue;
2302     }
2303     // Look for a splat value.
2304     if (all_of(drop_begin(Item), [Item](InstLane &IL) {
2305           auto [FrontU, FrontLane] = Item.front();
2306           auto [U, Lane] = IL;
2307           return !U || (U->get() == FrontU->get() && Lane == FrontLane);
2308         })) {
2309       SplatLeafs.insert(FrontU);
2310       continue;
2311     }
2312 
2313     // We need each element to be the same type of value, and check that each
2314     // element has a single use.
2315     auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
2316       Value *FrontV = Item.front().first->get();
2317       if (!IL.first)
2318         return true;
2319       Value *V = IL.first->get();
2320       if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
2321         return false;
2322       if (V->getValueID() != FrontV->getValueID())
2323         return false;
2324       if (auto *CI = dyn_cast<CmpInst>(V))
2325         if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
2326           return false;
2327       if (auto *CI = dyn_cast<CastInst>(V))
2328         if (CI->getSrcTy()->getScalarType() !=
2329             cast<CastInst>(FrontV)->getSrcTy()->getScalarType())
2330           return false;
2331       if (auto *SI = dyn_cast<SelectInst>(V))
2332         if (!isa<VectorType>(SI->getOperand(0)->getType()) ||
2333             SI->getOperand(0)->getType() !=
2334                 cast<SelectInst>(FrontV)->getOperand(0)->getType())
2335           return false;
2336       if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
2337         return false;
2338       auto *II = dyn_cast<IntrinsicInst>(V);
2339       return !II || (isa<IntrinsicInst>(FrontV) &&
2340                      II->getIntrinsicID() ==
2341                          cast<IntrinsicInst>(FrontV)->getIntrinsicID() &&
2342                      !II->hasOperandBundles());
2343     };
2344     if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) {
2345       // Check the operator is one that we support.
2346       if (isa<BinaryOperator, CmpInst>(FrontU)) {
2347         //  We exclude div/rem in case they hit UB from poison lanes.
2348         if (auto *BO = dyn_cast<BinaryOperator>(FrontU);
2349             BO && BO->isIntDivRem())
2350           return false;
2351         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2352         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
2353         continue;
2354       } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
2355                      FPToUIInst, SIToFPInst, UIToFPInst>(FrontU)) {
2356         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2357         continue;
2358       } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) {
2359         // TODO: Handle vector widening/narrowing bitcasts.
2360         auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
2361         auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
2362         if (DstTy && SrcTy &&
2363             SrcTy->getNumElements() == DstTy->getNumElements()) {
2364           Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2365           continue;
2366         }
2367       } else if (isa<SelectInst>(FrontU)) {
2368         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2369         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
2370         Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
2371         continue;
2372       } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
2373                  II && isTriviallyVectorizable(II->getIntrinsicID()) &&
2374                  !II->hasOperandBundles()) {
2375         for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
2376           if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op,
2377                                                  &TTI)) {
2378             if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
2379                   Value *FrontV = Item.front().first->get();
2380                   Use *U = IL.first;
2381                   return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
2382                                 cast<Instruction>(FrontV)->getOperand(Op));
2383                 }))
2384               return false;
2385             continue;
2386           }
2387           Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
2388         }
2389         continue;
2390       }
2391     }
2392 
2393     if (isFreeConcat(Item, CostKind, TTI)) {
2394       ConcatLeafs.insert(FrontU);
2395       continue;
2396     }
2397 
2398     return false;
2399   }
2400 
2401   if (NumVisited <= 1)
2402     return false;
2403 
2404   LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
2405 
2406   // If we got this far, we know the shuffles are superfluous and can be
2407   // removed. Scan through again and generate the new tree of instructions.
2408   Builder.SetInsertPoint(&I);
2409   Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
2410                                  ConcatLeafs, Builder, &TTI);
2411   replaceValue(I, *V);
2412   return true;
2413 }
2414 
2415 /// Given a commutative reduction, the order of the input lanes does not alter
2416 /// the results. We can use this to remove certain shuffles feeding the
2417 /// reduction, removing the need to shuffle at all.
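/// For example (illustrative IR): an add reduction is insensitive to lane
/// order, so the reversing shuffle below can be rewritten with a sorted
/// (identity-like) mask and subsequently folded away:
///   %s = shufflevector <4 x i32> %x, <4 x i32> poison,
///                      <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %s)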
2418 bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2419   auto *II = dyn_cast<IntrinsicInst>(&I);
2420   if (!II)
2421     return false;
2422   switch (II->getIntrinsicID()) {
2423   case Intrinsic::vector_reduce_add:
2424   case Intrinsic::vector_reduce_mul:
2425   case Intrinsic::vector_reduce_and:
2426   case Intrinsic::vector_reduce_or:
2427   case Intrinsic::vector_reduce_xor:
2428   case Intrinsic::vector_reduce_smin:
2429   case Intrinsic::vector_reduce_smax:
2430   case Intrinsic::vector_reduce_umin:
2431   case Intrinsic::vector_reduce_umax:
2432     break;
2433   default:
2434     return false;
2435   }
2436 
2437   // Find all the inputs when looking through operations that do not alter the
2438   // lane order (binops, for example). Currently we look for a single shuffle,
2439   // and can ignore splat values.
2440   std::queue<Value *> Worklist;
2441   SmallPtrSet<Value *, 4> Visited;
2442   ShuffleVectorInst *Shuffle = nullptr;
2443   if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
2444     Worklist.push(Op);
2445 
2446   while (!Worklist.empty()) {
2447     Value *CV = Worklist.front();
2448     Worklist.pop();
2449     if (Visited.contains(CV))
2450       continue;
2451 
2452     // Splats don't change the order, so can be safely ignored.
2453     if (isSplatValue(CV))
2454       continue;
2455 
2456     Visited.insert(CV);
2457 
2458     if (auto *CI = dyn_cast<Instruction>(CV)) {
2459       if (CI->isBinaryOp()) {
2460         for (auto *Op : CI->operand_values())
2461           Worklist.push(Op);
2462         continue;
2463       } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
2464         if (Shuffle && Shuffle != SV)
2465           return false;
2466         Shuffle = SV;
2467         continue;
2468       }
2469     }
2470 
2471     // Anything else is currently an unknown node.
2472     return false;
2473   }
2474 
2475   if (!Shuffle)
2476     return false;
2477 
2478   // Check all uses of the binary ops and shuffles are also included in the
2479   // lane-invariant operations (Visited should be the list of lanewise
2480   // instructions, including the shuffle that we found).
2481   for (auto *V : Visited)
2482     for (auto *U : V->users())
2483       if (!Visited.contains(U) && U != &I)
2484         return false;
2485 
2486   FixedVectorType *VecType =
2487       dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
2488   if (!VecType)
2489     return false;
2490   FixedVectorType *ShuffleInputType =
2491       dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
2492   if (!ShuffleInputType)
2493     return false;
2494   unsigned NumInputElts = ShuffleInputType->getNumElements();
2495 
  // Find the mask by sorting the lanes into order. This is most likely to
  // become an identity or concat mask. Undef elements are pushed to the end.
2498   SmallVector<int> ConcatMask;
2499   Shuffle->getShuffleMask(ConcatMask);
2500   sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
2501   // In the case of a truncating shuffle it's possible for the mask
2502   // to have an index greater than the size of the resulting vector.
2503   // This requires special handling.
2504   bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts;
2505   bool UsesSecondVec =
2506       any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
2507 
2508   FixedVectorType *VecTyForCost =
2509       (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType;
2510   InstructionCost OldCost = TTI.getShuffleCost(
2511       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
2512       VecTyForCost, Shuffle->getShuffleMask(), CostKind);
2513   InstructionCost NewCost = TTI.getShuffleCost(
2514       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
2515       VecTyForCost, ConcatMask, CostKind);
2516 
2517   LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
2518                     << "\n");
2519   LLVM_DEBUG(dbgs() << "  OldCost: " << OldCost << " vs NewCost: " << NewCost
2520                     << "\n");
2521   if (NewCost < OldCost) {
2522     Builder.SetInsertPoint(Shuffle);
2523     Value *NewShuffle = Builder.CreateShuffleVector(
2524         Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
2525     LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
2526     replaceValue(*Shuffle, *NewShuffle);
2527   }
2528 
  // See if we can re-use foldSelectShuffle, getting it to reduce the size of
  // the shuffle or put it into a nicer order, as the reduction can ignore the
  // order of the lanes.
2531   return foldSelectShuffle(*Shuffle, true);
2532 }
2533 
/// Determine if it's more efficient to fold:
2535 ///   reduce(trunc(x)) -> trunc(reduce(x)).
2536 ///   reduce(sext(x))  -> sext(reduce(x)).
2537 ///   reduce(zext(x))  -> zext(reduce(x)).
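/// For example (illustrative IR):
///   %t = trunc <8 x i32> %x to <8 x i16>
///   %r = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %t)
/// -->
///   %a = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %x)
///   %r = trunc i32 %a to i16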
2538 bool VectorCombine::foldCastFromReductions(Instruction &I) {
2539   auto *II = dyn_cast<IntrinsicInst>(&I);
2540   if (!II)
2541     return false;
2542 
2543   bool TruncOnly = false;
2544   Intrinsic::ID IID = II->getIntrinsicID();
2545   switch (IID) {
2546   case Intrinsic::vector_reduce_add:
2547   case Intrinsic::vector_reduce_mul:
2548     TruncOnly = true;
2549     break;
2550   case Intrinsic::vector_reduce_and:
2551   case Intrinsic::vector_reduce_or:
2552   case Intrinsic::vector_reduce_xor:
2553     break;
2554   default:
2555     return false;
2556   }
2557 
2558   unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
2559   Value *ReductionSrc = I.getOperand(0);
2560 
2561   Value *Src;
2562   if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
2563       (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
2564     return false;
2565 
2566   auto CastOpc =
2567       (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
2568 
2569   auto *SrcTy = cast<VectorType>(Src->getType());
2570   auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
2571   Type *ResultTy = I.getType();
2572 
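  // Compare the existing vector cast plus the reduction in the cast type
  // against performing the reduction in the pre-cast type and casting the
  // scalar result.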
2573   InstructionCost OldCost = TTI.getArithmeticReductionCost(
2574       ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
2575   OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
2576                                   TTI::CastContextHint::None, CostKind,
2577                                   cast<CastInst>(ReductionSrc));
2578   InstructionCost NewCost =
2579       TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
2580                                      CostKind) +
2581       TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
2582                            TTI::CastContextHint::None, CostKind);
2583 
2584   if (OldCost <= NewCost || !NewCost.isValid())
2585     return false;
2586 
2587   Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
2588                                                 II->getIntrinsicID(), {Src});
2589   Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
2590   replaceValue(I, *NewCast);
2591   return true;
2592 }
2593 
2594 /// This method looks for groups of shuffles acting on binops, of the form:
2595 ///  %x = shuffle ...
2596 ///  %y = shuffle ...
2597 ///  %a = binop %x, %y
2598 ///  %b = binop %x, %y
2599 ///  shuffle %a, %b, selectmask
2600 /// We may, especially if the shuffle is wider than legal, be able to convert
2601 /// the shuffle to a form where only parts of a and b need to be computed. On
2602 /// architectures with no obvious "select" shuffle, this can reduce the total
2603 /// number of operations if the target reports them as cheaper.
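/// For example (illustrative), with <8 x i32> operands a select mask such as
/// <0,9,2,11,4,13,6,15> takes even lanes from %a and odd lanes from %b, so
/// only half of each binop's lanes are actually demanded.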
2604 bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
2605   auto *SVI = cast<ShuffleVectorInst>(&I);
2606   auto *VT = cast<FixedVectorType>(I.getType());
2607   auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
2608   auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
2609   if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
2610       VT != Op0->getType())
2611     return false;
2612 
2613   auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
2614   auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
2615   auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
2616   auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
2617   SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
2618   auto checkSVNonOpUses = [&](Instruction *I) {
2619     if (!I || I->getOperand(0)->getType() != VT)
2620       return true;
2621     return any_of(I->users(), [&](User *U) {
2622       return U != Op0 && U != Op1 &&
2623              !(isa<ShuffleVectorInst>(U) &&
2624                (InputShuffles.contains(cast<Instruction>(U)) ||
2625                 isInstructionTriviallyDead(cast<Instruction>(U))));
2626     });
2627   };
2628   if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
2629       checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
2630     return false;
2631 
2632   // Collect all the uses that are shuffles that we can transform together. We
2633   // may not have a single shuffle, but a group that can all be transformed
2634   // together profitably.
2635   SmallVector<ShuffleVectorInst *> Shuffles;
2636   auto collectShuffles = [&](Instruction *I) {
2637     for (auto *U : I->users()) {
2638       auto *SV = dyn_cast<ShuffleVectorInst>(U);
2639       if (!SV || SV->getType() != VT)
2640         return false;
2641       if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
2642           (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
2643         return false;
2644       if (!llvm::is_contained(Shuffles, SV))
2645         Shuffles.push_back(SV);
2646     }
2647     return true;
2648   };
2649   if (!collectShuffles(Op0) || !collectShuffles(Op1))
2650     return false;
2651   // From a reduction, we need to be processing a single shuffle, otherwise the
2652   // other uses will not be lane-invariant.
2653   if (FromReduction && Shuffles.size() > 1)
2654     return false;
2655 
2656   // Add any shuffle uses for the shuffles we have found, to include them in our
2657   // cost calculations.
2658   if (!FromReduction) {
2659     for (ShuffleVectorInst *SV : Shuffles) {
2660       for (auto *U : SV->users()) {
2661         ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
2662         if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
2663           Shuffles.push_back(SSV);
2664       }
2665     }
2666   }
2667 
2668   // For each of the output shuffles, we try to sort all the first vector
2669   // elements to the beginning, followed by the second vector elements at the
2670   // end. If the binops are legalized to smaller vectors, this may reduce the
2671   // total number of binops. We compute the ReconstructMask needed to convert
2672   // back to the original lane order.
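  // E.g. (illustrative): if a use shuffle only reads lanes 6 and 1 of Op0 and
  // lane 3 of Op1, then V1 packs original lanes {6,1} into positions {0,1},
  // V2 packs lane {3} into position 0, and the ReconstructMask maps those
  // packed positions back to the lane order the use expects.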
2673   SmallVector<std::pair<int, int>> V1, V2;
2674   SmallVector<SmallVector<int>> OrigReconstructMasks;
2675   int MaxV1Elt = 0, MaxV2Elt = 0;
2676   unsigned NumElts = VT->getNumElements();
2677   for (ShuffleVectorInst *SVN : Shuffles) {
2678     SmallVector<int> Mask;
2679     SVN->getShuffleMask(Mask);
2680 
2681     // Check the operands are the same as the original, or reversed (in which
2682     // case we need to commute the mask).
2683     Value *SVOp0 = SVN->getOperand(0);
2684     Value *SVOp1 = SVN->getOperand(1);
2685     if (isa<UndefValue>(SVOp1)) {
2686       auto *SSV = cast<ShuffleVectorInst>(SVOp0);
2687       SVOp0 = SSV->getOperand(0);
2688       SVOp1 = SSV->getOperand(1);
2689       for (unsigned I = 0, E = Mask.size(); I != E; I++) {
2690         if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
2691           return false;
2692         Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
2693       }
2694     }
2695     if (SVOp0 == Op1 && SVOp1 == Op0) {
2696       std::swap(SVOp0, SVOp1);
2697       ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
2698     }
2699     if (SVOp0 != Op0 || SVOp1 != Op1)
2700       return false;
2701 
2702     // Calculate the reconstruction mask for this shuffle, as the mask needed to
2703     // take the packed values from Op0/Op1 and reconstruct the original lane
2704     // order.
2705     SmallVector<int> ReconstructMask;
2706     for (unsigned I = 0; I < Mask.size(); I++) {
2707       if (Mask[I] < 0) {
2708         ReconstructMask.push_back(-1);
2709       } else if (Mask[I] < static_cast<int>(NumElts)) {
2710         MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
2711         auto It = find_if(V1, [&](const std::pair<int, int> &A) {
2712           return Mask[I] == A.first;
2713         });
2714         if (It != V1.end())
2715           ReconstructMask.push_back(It - V1.begin());
2716         else {
2717           ReconstructMask.push_back(V1.size());
2718           V1.emplace_back(Mask[I], V1.size());
2719         }
2720       } else {
2721         MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
2722         auto It = find_if(V2, [&](const std::pair<int, int> &A) {
2723           return Mask[I] - static_cast<int>(NumElts) == A.first;
2724         });
2725         if (It != V2.end())
2726           ReconstructMask.push_back(NumElts + It - V2.begin());
2727         else {
2728           ReconstructMask.push_back(NumElts + V2.size());
2729           V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
2730         }
2731       }
2732     }
2733 
2734     // For reductions, we know that the output lane ordering doesn't alter the
2735     // result. Sorting the mask in-order can help simplify the shuffle away.
2736     if (FromReduction)
2737       sort(ReconstructMask);
2738     OrigReconstructMasks.push_back(std::move(ReconstructMask));
2739   }
2740 
2741   // If the maximum elements used from V1 and V2 are not larger than the new
2742   // vectors, the vectors are already packed and performing the optimization
2743   // again will likely not help any further. This also prevents us from getting
2744   // stuck in a cycle in case the costs do not also rule it out.
2745   if (V1.empty() || V2.empty() ||
2746       (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
2747        MaxV2Elt == static_cast<int>(V2.size()) - 1))
2748     return false;
2749 
2750   // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
2751   // shuffle of another shuffle, or not a shuffle (which is treated like an
2752   // identity shuffle).
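  // In effect it maps lane M of a binop operand back to a lane index into that
  // operand's original (pre-shuffle) sources.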
2753   auto GetBaseMaskValue = [&](Instruction *I, int M) {
2754     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2755     if (!SV)
2756       return M;
2757     if (isa<UndefValue>(SV->getOperand(1)))
2758       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
2759         if (InputShuffles.contains(SSV))
2760           return SSV->getMaskValue(SV->getMaskValue(M));
2761     return SV->getMaskValue(M);
2762   };
2763 
2764   // Attempt to sort the inputs by ascending mask values to make simpler input
2765   // shuffles and push complex shuffles down to the uses. We sort on the first
2766   // of the two input shuffle orders, to try and get at least one input into a
2767   // nice order.
2768   auto SortBase = [&](Instruction *A, std::pair<int, int> X,
2769                       std::pair<int, int> Y) {
2770     int MXA = GetBaseMaskValue(A, X.first);
2771     int MYA = GetBaseMaskValue(A, Y.first);
2772     return MXA < MYA;
2773   };
2774   stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
2775     return SortBase(SVI0A, A, B);
2776   });
2777   stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
2778     return SortBase(SVI1A, A, B);
2779   });
2780   // Calculate our ReconstructMasks from the OrigReconstructMasks and the
2781   // modified order of the input shuffles.
2782   SmallVector<SmallVector<int>> ReconstructMasks;
2783   for (const auto &Mask : OrigReconstructMasks) {
2784     SmallVector<int> ReconstructMask;
2785     for (int M : Mask) {
2786       auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
2787         auto It = find_if(V, [M](auto A) { return A.second == M; });
2788         assert(It != V.end() && "Expected all entries in Mask");
2789         return std::distance(V.begin(), It);
2790       };
2791       if (M < 0)
2792         ReconstructMask.push_back(-1);
2793       else if (M < static_cast<int>(NumElts)) {
2794         ReconstructMask.push_back(FindIndex(V1, M));
2795       } else {
2796         ReconstructMask.push_back(NumElts + FindIndex(V2, M));
2797       }
2798     }
2799     ReconstructMasks.push_back(std::move(ReconstructMask));
2800   }
2801 
2802   // Calculate the masks needed for the new input shuffles, which get padded
2803   // out to the full vector width with poison elements.
2804   SmallVector<int> V1A, V1B, V2A, V2B;
2805   for (unsigned I = 0; I < V1.size(); I++) {
2806     V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
2807     V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
2808   }
2809   for (unsigned I = 0; I < V2.size(); I++) {
2810     V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
2811     V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
2812   }
2813   while (V1A.size() < NumElts) {
2814     V1A.push_back(PoisonMaskElem);
2815     V1B.push_back(PoisonMaskElem);
2816   }
2817   while (V2A.size() < NumElts) {
2818     V2A.push_back(PoisonMaskElem);
2819     V2B.push_back(PoisonMaskElem);
2820   }
2821 
2822   auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
2823     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2824     if (!SV)
2825       return C;
2826     return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
2827                                       ? TTI::SK_PermuteSingleSrc
2828                                       : TTI::SK_PermuteTwoSrc,
2829                                   VT, SV->getShuffleMask(), CostKind);
2830   };
2831   auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
2832     return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask, CostKind);
2833   };
2834 
2835   // Get the costs of the shuffles + binops before and after with the new
2836   // shuffle masks.
2837   InstructionCost CostBefore =
2838       TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
2839       TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
2840   CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
2841                                 InstructionCost(0), AddShuffleCost);
2842   CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
2843                                 InstructionCost(0), AddShuffleCost);
2844 
2845   // The lanes of the new binops past the used shuffle lengths will be unused.
2846   // These types attempt to get the correct cost for that from the target.
2847   FixedVectorType *Op0SmallVT =
2848       FixedVectorType::get(VT->getScalarType(), V1.size());
2849   FixedVectorType *Op1SmallVT =
2850       FixedVectorType::get(VT->getScalarType(), V2.size());
2851   InstructionCost CostAfter =
2852       TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
2853       TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
2854   CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
2855                                InstructionCost(0), AddShuffleMaskCost);
2856   std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
2857   CostAfter +=
2858       std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
2859                       InstructionCost(0), AddShuffleMaskCost);
2860 
2861   LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
2862   LLVM_DEBUG(dbgs() << "  CostBefore: " << CostBefore
2863                     << " vs CostAfter: " << CostAfter << "\n");
2864   if (CostBefore <= CostAfter)
2865     return false;
2866 
2867   // The cost model has passed; create the new instructions.
2868   auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
2869     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2870     if (!SV)
2871       return I;
2872     if (isa<UndefValue>(SV->getOperand(1)))
2873       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
2874         if (InputShuffles.contains(SSV))
2875           return SSV->getOperand(Op);
2876     return SV->getOperand(Op);
2877   };
2878   Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
2879   Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
2880                                              GetShuffleOperand(SVI0A, 1), V1A);
2881   Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
2882   Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
2883                                              GetShuffleOperand(SVI0B, 1), V1B);
2884   Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
2885   Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
2886                                              GetShuffleOperand(SVI1A, 1), V2A);
2887   Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
2888   Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
2889                                              GetShuffleOperand(SVI1B, 1), V2B);
2890   Builder.SetInsertPoint(Op0);
2891   Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
2892                                     NSV0A, NSV0B);
2893   if (auto *I = dyn_cast<Instruction>(NOp0))
2894     I->copyIRFlags(Op0, true);
2895   Builder.SetInsertPoint(Op1);
2896   Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
2897                                     NSV1A, NSV1B);
2898   if (auto *I = dyn_cast<Instruction>(NOp1))
2899     I->copyIRFlags(Op1, true);
2900 
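  // Rebuild each original use shuffle on top of the new, narrower binops using
  // the remapped reconstruction masks.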
2901   for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
2902     Builder.SetInsertPoint(Shuffles[S]);
2903     Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
2904     replaceValue(*Shuffles[S], *NSV);
2905   }
2906 
2907   Worklist.pushValue(NSV0A);
2908   Worklist.pushValue(NSV0B);
2909   Worklist.pushValue(NSV1A);
2910   Worklist.pushValue(NSV1B);
2911   for (auto *S : Shuffles)
2912     Worklist.add(S);
2913   return true;
2914 }
2915 
2916 /// Check if the instruction depends on a ZExt that can be moved after the
2917 /// instruction, and move the ZExt if it is profitable. For example:
2918 ///     logic(zext(x),y) -> zext(logic(x,trunc(y)))
2919 ///     lshr(zext(x),y)  -> zext(lshr(x,trunc(y)))
2920 /// The cost model takes into account whether zext(x) has other users and
2921 /// whether the zext can be propagated through them too.
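/// For example (illustrative IR), when the narrow operations are cheaper:
///   %z = zext <4 x i8> %x to <4 x i32>
///   %r = and <4 x i32> %z, %y
/// may become
///   %t = trunc <4 x i32> %y to <4 x i8>
///   %a = and <4 x i8> %x, %t
///   %r = zext <4 x i8> %a to <4 x i32>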
2922 bool VectorCombine::shrinkType(Instruction &I) {
2923   Value *ZExted, *OtherOperand;
2924   if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
2925                                   m_Value(OtherOperand))) &&
2926       !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
2927     return false;
2928 
2929   Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
2930 
2931   auto *BigTy = cast<FixedVectorType>(I.getType());
2932   auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
2933   unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
2934 
2935   if (I.getOpcode() == Instruction::LShr) {
2936     // Check that the shift amount is less than the number of bits in the
2937     // smaller type. Otherwise, the smaller lshr will return a poison value.
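    // E.g. (illustrative): "lshr (zext i8 %x to i32), 12" is well defined and
    // yields 0, but the shrunk "lshr i8 %x, 12" would be poison.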
2938     KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
2939     if (ShAmtKB.getMaxValue().uge(BW))
2940       return false;
2941   } else {
2942     // Check that the expression overall uses at most the same number of bits as
2943     // ZExted does.
2944     KnownBits KB = computeKnownBits(&I, *DL);
2945     if (KB.countMaxActiveBits() > BW)
2946       return false;
2947   }
2948 
2949   // Calculate the costs of leaving the current IR as it is versus moving the
2950   // ZExt operation later, along with adding truncates if needed.
2951   InstructionCost ZExtCost = TTI.getCastInstrCost(
2952       Instruction::ZExt, BigTy, SmallTy,
2953       TargetTransformInfo::CastContextHint::None, CostKind);
2954   InstructionCost CurrentCost = ZExtCost;
2955   InstructionCost ShrinkCost = 0;
2956 
2957   // Calculate total cost and check that we can propagate through all ZExt users
2958   for (User *U : ZExtOperand->users()) {
2959     auto *UI = cast<Instruction>(U);
2960     if (UI == &I) {
2961       CurrentCost +=
2962           TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2963       ShrinkCost +=
2964           TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2965       ShrinkCost += ZExtCost;
2966       continue;
2967     }
2968 
2969     if (!Instruction::isBinaryOp(UI->getOpcode()))
2970       return false;
2971 
2972     // Check if we can propagate ZExt through its other users
2973     KnownBits KB = computeKnownBits(UI, *DL);
2974     if (KB.countMaxActiveBits() > BW)
2975       return false;
2976 
2977     CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2978     ShrinkCost +=
2979         TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2980     ShrinkCost += ZExtCost;
2981   }
2982 
2983   // If the other instruction operand is not a constant, we'll need to
2984   // generate a truncate instruction, so we have to adjust the cost.
2985   if (!isa<Constant>(OtherOperand))
2986     ShrinkCost += TTI.getCastInstrCost(
2987         Instruction::Trunc, SmallTy, BigTy,
2988         TargetTransformInfo::CastContextHint::None, CostKind);
2989 
2990   // If the cost of shrinking types is the same as leaving the IR unchanged, we
2991   // lean towards modifying the IR because shrinking opens opportunities for
2992   // other shrinking optimizations.
2993   if (ShrinkCost > CurrentCost)
2994     return false;
2995 
2996   Builder.SetInsertPoint(&I);
2997   Value *Op0 = ZExted;
2998   Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
2999   // Keep the order of operands the same
3000   if (I.getOperand(0) == OtherOperand)
3001     std::swap(Op0, Op1);
3002   Value *NewBinOp =
3003       Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
3004   cast<Instruction>(NewBinOp)->copyIRFlags(&I);
3005   cast<Instruction>(NewBinOp)->copyMetadata(I);
3006   Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
3007   replaceValue(I, *NewZExtr);
3008   return true;
3009 }
3010 
3011 /// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
3012 /// shuffle (DstVec, SrcVec, Mask)
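/// For example (illustrative), with <4 x i32> vectors, ExtIdx = 2, InsIdx = 0:
///   %e = extractelement <4 x i32> %src, i64 2
///   %r = insertelement <4 x i32> %dst, i32 %e, i64 0
/// becomes
///   %r = shufflevector <4 x i32> %dst, <4 x i32> %src,
///                      <4 x i32> <i32 6, i32 1, i32 2, i32 3>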
3013 bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
3014   Value *DstVec, *SrcVec;
3015   uint64_t ExtIdx, InsIdx;
3016   if (!match(&I,
3017              m_InsertElt(m_Value(DstVec),
3018                          m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
3019                          m_ConstantInt(InsIdx))))
3020     return false;
3021 
3022   auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
3023   if (!VecTy || SrcVec->getType() != VecTy)
3024     return false;
3025 
3026   unsigned NumElts = VecTy->getNumElements();
3027   if (ExtIdx >= NumElts || InsIdx >= NumElts)
3028     return false;
3029 
3030   // Insertion into poison is a cheaper single operand shuffle.
3031   TargetTransformInfo::ShuffleKind SK;
3032   SmallVector<int> Mask(NumElts, PoisonMaskElem);
3033   if (isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
3034     SK = TargetTransformInfo::SK_PermuteSingleSrc;
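    // SrcVec becomes the first shuffle operand (see the swap below), so the
    // extract index is used directly, without the NumElts offset.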
3035     Mask[InsIdx] = ExtIdx;
3036     std::swap(DstVec, SrcVec);
3037   } else {
3038     SK = TargetTransformInfo::SK_PermuteTwoSrc;
3039     std::iota(Mask.begin(), Mask.end(), 0);
3040     Mask[InsIdx] = ExtIdx + NumElts;
3041   }
3042 
3043   // Cost
3044   auto *Ins = cast<InsertElementInst>(&I);
3045   auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
3046   InstructionCost InsCost =
3047       TTI.getVectorInstrCost(*Ins, VecTy, CostKind, InsIdx);
3048   InstructionCost ExtCost =
3049       TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx);
3050   InstructionCost OldCost = ExtCost + InsCost;
3051 
3052   InstructionCost NewCost = TTI.getShuffleCost(SK, VecTy, Mask, CostKind, 0,
3053                                                nullptr, {DstVec, SrcVec});
3054   if (!Ext->hasOneUse())
3055     NewCost += ExtCost;
3056 
3057   LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair: " << I
3058                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
3059                     << "\n");
3060 
3061   if (OldCost < NewCost)
3062     return false;
3063 
3064   // Canonicalize undef param to RHS to help further folds.
3065   if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
3066     ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
3067     std::swap(DstVec, SrcVec);
3068   }
3069 
3070   Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
3071   replaceValue(I, *Shuf);
3072 
3073   return true;
3074 }
3075 
3076 /// This is the entry point for all transforms. Pass manager differences are
3077 /// handled in the callers of this function.
3078 bool VectorCombine::run() {
3079   if (DisableVectorCombine)
3080     return false;
3081 
3082   // Don't attempt vectorization if the target does not support vectors.
3083   if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
3084     return false;
3085 
3086   LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
3087 
3088   bool MadeChange = false;
3089   auto FoldInst = [this, &MadeChange](Instruction &I) {
3090     Builder.SetInsertPoint(&I);
3091     bool IsVectorType = isa<VectorType>(I.getType());
3092     bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
3093     auto Opcode = I.getOpcode();
3094 
3095     LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
3096 
3097     // These folds should be beneficial regardless of when this pass is run
3098     // in the optimization pipeline.
3099     // The type checking is for run-time efficiency. We can avoid wasting time
3100     // dispatching to folding functions if there's no chance of matching.
3101     if (IsFixedVectorType) {
3102       switch (Opcode) {
3103       case Instruction::InsertElement:
3104         MadeChange |= vectorizeLoadInsert(I);
3105         break;
3106       case Instruction::ShuffleVector:
3107         MadeChange |= widenSubvectorLoad(I);
3108         break;
3109       default:
3110         break;
3111       }
3112     }
3113 
3114     // These transforms work with scalable and fixed vectors.
3115     // TODO: Identify and allow other scalable transforms.
3116     if (IsVectorType) {
3117       MadeChange |= scalarizeBinopOrCmp(I);
3118       MadeChange |= scalarizeLoadExtract(I);
3119       MadeChange |= scalarizeVPIntrinsic(I);
3120     }
3121 
3122     if (Opcode == Instruction::Store)
3123       MadeChange |= foldSingleElementStore(I);
3124 
3125     // If this is an early pipeline invocation of this pass, we are done.
3126     if (TryEarlyFoldsOnly)
3127       return;
3128 
3129     // Otherwise, try folds that improve codegen but may interfere with
3130     // early IR canonicalizations.
3131     // The type checking is for run-time efficiency. We can avoid wasting time
3132     // dispatching to folding functions if there's no chance of matching.
3133     if (IsFixedVectorType) {
3134       switch (Opcode) {
3135       case Instruction::InsertElement:
3136         MadeChange |= foldInsExtFNeg(I);
3137         MadeChange |= foldInsExtVectorToShuffle(I);
3138         break;
3139       case Instruction::ShuffleVector:
3140         MadeChange |= foldPermuteOfBinops(I);
3141         MadeChange |= foldShuffleOfBinops(I);
3142         MadeChange |= foldShuffleOfCastops(I);
3143         MadeChange |= foldShuffleOfShuffles(I);
3144         MadeChange |= foldShuffleOfIntrinsics(I);
3145         MadeChange |= foldSelectShuffle(I);
3146         MadeChange |= foldShuffleToIdentity(I);
3147         break;
3148       case Instruction::BitCast:
3149         MadeChange |= foldBitcastShuffle(I);
3150         break;
3151       default:
3152         MadeChange |= shrinkType(I);
3153         break;
3154       }
3155     } else {
3156       switch (Opcode) {
3157       case Instruction::Call:
3158         MadeChange |= foldShuffleFromReductions(I);
3159         MadeChange |= foldCastFromReductions(I);
3160         break;
3161       case Instruction::ICmp:
3162       case Instruction::FCmp:
3163         MadeChange |= foldExtractExtract(I);
3164         break;
3165       case Instruction::Or:
3166         MadeChange |= foldConcatOfBoolMasks(I);
3167         [[fallthrough]];
3168       default:
3169         if (Instruction::isBinaryOp(Opcode)) {
3170           MadeChange |= foldExtractExtract(I);
3171           MadeChange |= foldExtractedCmps(I);
3172         }
3173         break;
3174       }
3175     }
3176   };
3177 
3178   for (BasicBlock &BB : F) {
3179     // Ignore unreachable basic blocks.
3180     if (!DT.isReachableFromEntry(&BB))
3181       continue;
3182     // Use an early-increment range so that we can erase instructions in the loop.
3183     for (Instruction &I : make_early_inc_range(BB)) {
3184       if (I.isDebugOrPseudoInst())
3185         continue;
3186       FoldInst(I);
3187     }
3188   }
3189 
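  // Process instructions queued by the folds above; trivially dead ones are
  // erased and the rest get another chance to fold.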
3190   while (!Worklist.isEmpty()) {
3191     Instruction *I = Worklist.removeOne();
3192     if (!I)
3193       continue;
3194 
3195     if (isInstructionTriviallyDead(I)) {
3196       eraseInstruction(*I);
3197       continue;
3198     }
3199 
3200     FoldInst(*I);
3201   }
3202 
3203   return MadeChange;
3204 }
3205 
3206 PreservedAnalyses VectorCombinePass::run(Function &F,
3207                                          FunctionAnalysisManager &FAM) {
3208   auto &AC = FAM.getResult<AssumptionAnalysis>(F);
3209   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
3210   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
3211   AAResults &AA = FAM.getResult<AAManager>(F);
3212   const DataLayout *DL = &F.getDataLayout();
3213   VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
3214                          TryEarlyFoldsOnly);
3215   if (!Combiner.run())
3216     return PreservedAnalyses::all();
3217   PreservedAnalyses PA;
3218   PA.preserveSet<CFGAnalyses>();
3219   return PA;
3220 }
3221