//===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass converts vector operations into scalar operations (or, optionally,
// operations on smaller vector widths), in order to expose optimization
// opportunities on the individual scalar operations.
// It is mainly intended for targets that do not have vector units, but it
// may also be useful for revectorizing code to different vector widths.
//
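// For example (illustrative), when fully scalarizing, an operation such as
//   %sum = fadd <2 x float> %a, %b
// becomes extracts of the elements of %a and %b, two scalar fadd
// instructions, and, if the full vector result is still needed, inserts that
// rebuild %sum.
//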
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/Scalarizer.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
#include <iterator>
#include <map>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "scalarizer"

namespace {

BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
  BasicBlock *BB = Itr->getParent();
  if (isa<PHINode>(Itr))
    Itr = BB->getFirstInsertionPt();
  if (Itr != BB->end())
    Itr = skipDebugIntrinsics(Itr);
  return Itr;
}

// Used to store the scattered form of a vector.
using ValueVector = SmallVector<Value *, 8>;

// Used to map a vector Value and associated type to its scattered form.
// The associated type is only non-null for pointer values that are "scattered"
// when used as pointer operands to load or store.
//
// We use std::map because we want iterators to persist across insertion and
// because the values are relatively large.
using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>;

// Lists Instructions that have been replaced with scalar implementations,
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;

struct VectorSplit {
  // The type of the vector.
  FixedVectorType *VecTy = nullptr;

  // The number of elements packed in a fragment (other than the remainder).
  unsigned NumPacked = 0;

  // The number of fragments (scalars or smaller vectors) into which the vector
  // shall be split.
  unsigned NumFragments = 0;

  // The type of each complete fragment.
  Type *SplitTy = nullptr;

  // The type of the remainder (last) fragment; null if all fragments are
  // complete.
  Type *RemainderTy = nullptr;

  Type *getFragmentType(unsigned I) const {
    return RemainderTy && I == NumFragments - 1 ? RemainderTy : SplitTy;
  }
};
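// Illustrative example: with ScalarizeMinBits=32, a <5 x i16> vector splits
// into NumPacked=2, NumFragments=3, SplitTy=<2 x i16> and RemainderTy=i16,
// i.e. fragment types <2 x i16>, <2 x i16> and i16.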

// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
public:
  Scatterer() = default;

  // Scatter V into the fragments described by VS.  If new instructions are
  // needed, insert them before BBI in BB.  If CachePtr is nonnull, use it to
  // cache the results.
  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
            const VectorSplit &VS, ValueVector *cachePtr = nullptr);

  // Return component I, creating a new Value for it if necessary.
  Value *operator[](unsigned I);

  // Return the number of components.
  unsigned size() const { return VS.NumFragments; }

private:
  BasicBlock *BB;
  BasicBlock::iterator BBI;
  Value *V;
  VectorSplit VS;
  bool IsPointer;
  ValueVector *CachePtr;
  ValueVector Tmp;
};

// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
// called Name that compares X and Y in the same way as FCI.
struct FCmpSplitter {
  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
  }

  FCmpInst &FCI;
};

// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
// called Name that compares X and Y in the same way as ICI.
struct ICmpSplitter {
  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
  }

  ICmpInst &ICI;
};

// UnarySplitter(UO)(Builder, X, Name) uses Builder to create
// a unary operator like UO called Name with operand X.
struct UnarySplitter {
  UnarySplitter(UnaryOperator &uo) : UO(uo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
    return Builder.CreateUnOp(UO.getOpcode(), Op, Name);
  }

  UnaryOperator &UO;
};

// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
// a binary operator like BO called Name with operands X and Y.
struct BinarySplitter {
  BinarySplitter(BinaryOperator &bo) : BO(bo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
  }

  BinaryOperator &BO;
};

// Information about a load or store that we're scalarizing.
struct VectorLayout {
  VectorLayout() = default;

  // Return the alignment of fragment Frag.
  Align getFragmentAlign(unsigned Frag) {
    return commonAlignment(VecAlign, Frag * SplitSize);
  }

  // The split of the underlying vector type.
  VectorSplit VS;

  // The alignment of the vector.
  Align VecAlign;

  // The size of each (non-remainder) fragment in bytes.
  uint64_t SplitSize = 0;
};
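// Illustrative example: for a vector with VecAlign=16 and SplitSize=8, the
// fragments at byte offsets 0 and 8 get alignments of 16 and 8 respectively.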

static bool isStructOfMatchingFixedVectors(Type *Ty) {
  if (!isa<StructType>(Ty))
    return false;
  unsigned StructSize = Ty->getNumContainedTypes();
  if (StructSize < 1)
    return false;
  FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
  if (!VecTy)
    return false;
  unsigned VecSize = VecTy->getNumElements();
  for (unsigned I = 1; I < StructSize; I++) {
    VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
    if (!VecTy || VecSize != VecTy->getNumElements())
      return false;
  }
  return true;
}

/// Concatenate the given fragments to a single vector value of the type
/// described in @p VS.
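///
/// Illustrative example: rebuilding a <4 x i16> from two <2 x i16> fragments
/// first widens each fragment with an ExtendMask of <0, 1, -1, -1> and then
/// blends the second one in with an InsertMask of <0, 1, 4, 5>.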
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
                          const VectorSplit &VS, Twine Name) {
  unsigned NumElements = VS.VecTy->getNumElements();
  SmallVector<int> ExtendMask;
  SmallVector<int> InsertMask;

  if (VS.NumPacked > 1) {
    // Prepare the shufflevector masks once and re-use them for all
    // fragments.
    ExtendMask.resize(NumElements, -1);
    for (unsigned I = 0; I < VS.NumPacked; ++I)
      ExtendMask[I] = I;

    InsertMask.resize(NumElements);
    for (unsigned I = 0; I < NumElements; ++I)
      InsertMask[I] = I;
  }

  Value *Res = PoisonValue::get(VS.VecTy);
  for (unsigned I = 0; I < VS.NumFragments; ++I) {
    Value *Fragment = Fragments[I];

    unsigned NumPacked = VS.NumPacked;
    if (I == VS.NumFragments - 1 && VS.RemainderTy) {
      if (auto *RemVecTy = dyn_cast<FixedVectorType>(VS.RemainderTy))
        NumPacked = RemVecTy->getNumElements();
      else
        NumPacked = 1;
    }

    if (NumPacked == 1) {
      Res = Builder.CreateInsertElement(Res, Fragment, I * VS.NumPacked,
                                        Name + ".upto" + Twine(I));
    } else {
      Fragment = Builder.CreateShuffleVector(Fragment, Fragment, ExtendMask);
      if (I == 0) {
        Res = Fragment;
      } else {
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * VS.NumPacked + J] = NumElements + J;
        Res = Builder.CreateShuffleVector(Res, Fragment, InsertMask,
                                          Name + ".upto" + Twine(I));
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * VS.NumPacked + J] = I * VS.NumPacked + J;
      }
    }
  }

  return Res;
}

class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
public:
  ScalarizerVisitor(DominatorTree *DT, const TargetTransformInfo *TTI,
                    ScalarizerPassOptions Options)
      : DT(DT), TTI(TTI),
        ScalarizeVariableInsertExtract(Options.ScalarizeVariableInsertExtract),
        ScalarizeLoadStore(Options.ScalarizeLoadStore),
        ScalarizeMinBits(Options.ScalarizeMinBits) {}

  bool visit(Function &F);

  // InstVisitor methods.  They return true if the instruction was scalarized,
  // false if nothing changed.
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI);
  bool visitICmpInst(ICmpInst &ICI);
  bool visitFCmpInst(FCmpInst &FCI);
  bool visitUnaryOperator(UnaryOperator &UO);
  bool visitBinaryOperator(BinaryOperator &BO);
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitCastInst(CastInst &CI);
  bool visitBitCastInst(BitCastInst &BCI);
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitExtractValueInst(ExtractValueInst &EVI);
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
  bool visitPHINode(PHINode &PHI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &CI);
  bool visitFreezeInst(FreezeInst &FI);

private:
  Scatterer scatter(Instruction *Point, Value *V, const VectorSplit &VS);
  void gather(Instruction *Op, const ValueVector &CV, const VectorSplit &VS);
  void replaceUses(Instruction *Op, Value *CV);
  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
  std::optional<VectorSplit> getVectorSplit(Type *Ty);
  std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
                                              const DataLayout &DL);
  bool finish();

  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);

  bool splitCall(CallInst &CI);

  ScatterMap Scattered;
  GatherList Gathered;
  bool Scalarized;

  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;

  DominatorTree *DT;
  const TargetTransformInfo *TTI;

  const bool ScalarizeVariableInsertExtract;
  const bool ScalarizeLoadStore;
  const unsigned ScalarizeMinBits;
};

class ScalarizerLegacyPass : public FunctionPass {
public:
  static char ID;
  ScalarizerPassOptions Options;
  ScalarizerLegacyPass() : FunctionPass(ID), Options() {}
  ScalarizerLegacyPass(const ScalarizerPassOptions &Options);
  bool runOnFunction(Function &F) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;
};

} // end anonymous namespace

ScalarizerLegacyPass::ScalarizerLegacyPass(const ScalarizerPassOptions &Options)
    : FunctionPass(ID), Options(Options) {}

void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
}

char ScalarizerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
                      "Scalarize vector operations", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
                    "Scalarize vector operations", false, false)

Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                     const VectorSplit &VS, ValueVector *cachePtr)
    : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
  IsPointer = V->getType()->isPointerTy();
  if (!CachePtr) {
    Tmp.resize(VS.NumFragments, nullptr);
  } else {
    assert((CachePtr->empty() || VS.NumFragments == CachePtr->size() ||
            IsPointer) &&
           "Inconsistent vector sizes");
    if (VS.NumFragments > CachePtr->size())
      CachePtr->resize(VS.NumFragments, nullptr);
  }
}

// Return fragment Frag, creating a new Value for it if necessary.
Value *Scatterer::operator[](unsigned Frag) {
  ValueVector &CV = CachePtr ? *CachePtr : Tmp;
  // Try to reuse a previous value.
  if (CV[Frag])
    return CV[Frag];
  IRBuilder<> Builder(BB, BBI);
  if (IsPointer) {
    if (Frag == 0)
      CV[Frag] = V;
    else
      CV[Frag] = Builder.CreateConstGEP1_32(VS.SplitTy, V, Frag,
                                            V->getName() + ".i" + Twine(Frag));
    return CV[Frag];
  }

  Type *FragmentTy = VS.getFragmentType(Frag);

  if (auto *VecTy = dyn_cast<FixedVectorType>(FragmentTy)) {
    SmallVector<int> Mask;
    for (unsigned J = 0; J < VecTy->getNumElements(); ++J)
      Mask.push_back(Frag * VS.NumPacked + J);
    CV[Frag] =
        Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()), Mask,
                                    V->getName() + ".i" + Twine(Frag));
  } else {
    // Search through a chain of InsertElementInsts looking for element Frag.
    // Record other elements in the cache.  The new V is still suitable
    // for all uncached indices.
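    // For example (illustrative), if V is
    //   %v1 = insertelement <4 x float> %v0, float %x, i64 1
    // then a request for fragment 1 (with NumPacked == 1) returns %x directly,
    // without creating an extractelement.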
    while (true) {
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
      if (!Insert)
        break;
      ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
      if (!Idx)
        break;
      unsigned J = Idx->getZExtValue();
      V = Insert->getOperand(0);
      if (Frag * VS.NumPacked == J) {
        CV[Frag] = Insert->getOperand(1);
        return CV[Frag];
      }

      if (VS.NumPacked == 1 && !CV[J]) {
        // Only cache the first entry we find for each index we're not actively
        // searching for. This prevents us from going too far up the chain and
        // caching incorrect entries.
        CV[J] = Insert->getOperand(1);
      }
    }
    CV[Frag] = Builder.CreateExtractElement(V, Frag * VS.NumPacked,
                                            V->getName() + ".i" + Twine(Frag));
  }

  return CV[Frag];
}

bool ScalarizerLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  const TargetTransformInfo *TTI =
      &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  ScalarizerVisitor Impl(DT, TTI, Options);
  return Impl.visit(F);
}

FunctionPass *llvm::createScalarizerPass(const ScalarizerPassOptions &Options) {
  return new ScalarizerLegacyPass(Options);
}

bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());

  Scalarized = false;

  // To ensure we replace gathered components correctly we need to do an ordered
  // traversal of the basic blocks in the function.
  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
  for (BasicBlock *BB : RPOT) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      Instruction *I = &*II;
      bool Done = InstVisitor::visit(I);
      ++II;
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
    }
  }
  return finish();
}

// Return a scattered form of V that can be accessed by Point.  V must be a
// vector or a pointer to a vector.
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
                                     const VectorSplit &VS) {
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    // Put the scattered form of arguments in the entry block,
    // so that it can be used everywhere.
    Function *F = VArg->getParent();
    BasicBlock *BB = &F->getEntryBlock();
    return Scatterer(BB, BB->begin(), V, VS, &Scattered[{V, VS.SplitTy}]);
  }
  if (Instruction *VOp = dyn_cast<Instruction>(V)) {
    // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
    // nodes in predecessors. If those predecessors are unreachable from entry,
    // then the IR in those blocks could have unexpected properties resulting in
    // infinite loops in Scatterer::operator[]. By simply treating values
    // originating from instructions in unreachable blocks as poison we do not
    // need to analyse them further.
    if (!DT->isReachableFromEntry(VOp->getParent()))
      return Scatterer(Point->getParent(), Point->getIterator(),
                       PoisonValue::get(V->getType()), VS);
    // Put the scattered form of an instruction directly after the
    // instruction, skipping over PHI nodes and debug intrinsics.
    BasicBlock *BB = VOp->getParent();
    return Scatterer(
        BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, VS,
        &Scattered[{V, VS.SplitTy}]);
  }
  // In the fallback case, just put the scattered form before Point and
  // keep the result local to Point.
  return Scatterer(Point->getParent(), Point->getIterator(), V, VS);
}

// Replace Op with the gathered form of the components in CV.  Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV,
                               const VectorSplit &VS) {
  transferMetadataAndIRFlags(Op, CV);

  // If we already have a scattered form of Op (created from ExtractElements
  // of Op itself), replace them with the new form.
  ValueVector &SV = Scattered[{Op, VS.SplitTy}];
  if (!SV.empty()) {
    for (unsigned I = 0, E = SV.size(); I != E; ++I) {
      Value *V = SV[I];
      if (V == nullptr || SV[I] == CV[I])
        continue;

      Instruction *Old = cast<Instruction>(V);
      if (isa<Instruction>(CV[I]))
        CV[I]->takeName(Old);
      Old->replaceAllUsesWith(CV[I]);
      PotentiallyDeadInstrs.emplace_back(Old);
    }
  }
  SV = CV;
  Gathered.push_back(GatherList::value_type(Op, &SV));
}

// Replace Op with CV and collect Op as a potentially dead instruction.
void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) {
  if (CV != Op) {
    Op->replaceAllUsesWith(CV);
    PotentiallyDeadInstrs.emplace_back(Op);
    Scalarized = true;
  }
}

// Return true if it is safe to transfer the given metadata tag from
// vector to scalar instructions.
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
  return (Tag == LLVMContext::MD_tbaa
          || Tag == LLVMContext::MD_fpmath
          || Tag == LLVMContext::MD_tbaa_struct
          || Tag == LLVMContext::MD_invariant_load
          || Tag == LLVMContext::MD_alias_scope
          || Tag == LLVMContext::MD_noalias
          || Tag == LLVMContext::MD_mem_parallel_loop_access
          || Tag == LLVMContext::MD_access_group);
}

// Transfer metadata from Op to the instructions in CV if it is known
// to be safe to do so.
void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
                                                   const ValueVector &CV) {
  SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
  Op->getAllMetadataOtherThanDebugLoc(MDs);
  for (Value *V : CV) {
    if (Instruction *New = dyn_cast<Instruction>(V)) {
      for (const auto &MD : MDs)
        if (canTransferMetadata(MD.first))
          New->setMetadata(MD.first, MD.second);
      New->copyIRFlags(Op);
      if (Op->getDebugLoc() && !New->getDebugLoc())
        New->setDebugLoc(Op->getDebugLoc());
    }
  }
}

// Determine how Ty is split, if at all.
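// For example (illustrative): with ScalarizeMinBits=32, a <6 x i16> vector
// splits into three <2 x i16> fragments, while with ScalarizeMinBits=0 (the
// default) every vector is fully scalarized, i.e. NumPacked == 1.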
std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
  VectorSplit Split;
  Split.VecTy = dyn_cast<FixedVectorType>(Ty);
  if (!Split.VecTy)
    return {};

  unsigned NumElems = Split.VecTy->getNumElements();
  Type *ElemTy = Split.VecTy->getElementType();

  if (NumElems == 1 || ElemTy->isPointerTy() ||
      2 * ElemTy->getScalarSizeInBits() > ScalarizeMinBits) {
    Split.NumPacked = 1;
    Split.NumFragments = NumElems;
    Split.SplitTy = ElemTy;
  } else {
    Split.NumPacked = ScalarizeMinBits / ElemTy->getScalarSizeInBits();
    if (Split.NumPacked >= NumElems)
      return {};

    Split.NumFragments = divideCeil(NumElems, Split.NumPacked);
    Split.SplitTy = FixedVectorType::get(ElemTy, Split.NumPacked);

    unsigned RemainderElems = NumElems % Split.NumPacked;
    if (RemainderElems > 1)
      Split.RemainderTy = FixedVectorType::get(ElemTy, RemainderElems);
    else if (RemainderElems == 1)
      Split.RemainderTy = ElemTy;
  }

  return Split;
}

// Try to compute the layout of the vector type Ty, returning std::nullopt if
// the type cannot be split into full-byte fragments.  Alignment is the
// alignment of the vector.
std::optional<VectorLayout>
ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
                                   const DataLayout &DL) {
  std::optional<VectorSplit> VS = getVectorSplit(Ty);
  if (!VS)
    return {};

  VectorLayout Layout;
  Layout.VS = *VS;
  // Check that we're dealing with full-byte fragments.
  if (!DL.typeSizeEqualsStoreSize(VS->SplitTy) ||
      (VS->RemainderTy && !DL.typeSizeEqualsStoreSize(VS->RemainderTy)))
    return {};
  Layout.VecAlign = Alignment;
  Layout.SplitSize = DL.getTypeStoreSize(VS->SplitTy);
  return Layout;
}

// Scalarize one-operand instruction I, using Split(Builder, X, Name)
// to create an instruction like I with operand X and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
  std::optional<VectorSplit> VS = getVectorSplit(I.getType());
  if (!VS)
    return false;

  std::optional<VectorSplit> OpVS;
  if (I.getOperand(0)->getType() == I.getType()) {
    OpVS = VS;
  } else {
    OpVS = getVectorSplit(I.getOperand(0)->getType());
    if (!OpVS || VS->NumPacked != OpVS->NumPacked)
      return false;
  }

  IRBuilder<> Builder(&I);
  Scatterer Op = scatter(&I, I.getOperand(0), *OpVS);
  assert(Op.size() == VS->NumFragments && "Mismatched unary operation");
  ValueVector Res;
  Res.resize(VS->NumFragments);
  for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag)
    Res[Frag] = Split(Builder, Op[Frag], I.getName() + ".i" + Twine(Frag));
  gather(&I, Res, *VS);
  return true;
}

// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
// to create an instruction like I with operands X and Y and name Name.
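//
// For example (illustrative), a fully scalarized
//   %r = add <2 x i32> %a, %b
// becomes
//   %r.i0 = add i32 %a.i0, %b.i0
//   %r.i1 = add i32 %a.i1, %b.i1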
template<typename Splitter>
bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
  std::optional<VectorSplit> VS = getVectorSplit(I.getType());
  if (!VS)
    return false;

  std::optional<VectorSplit> OpVS;
  if (I.getOperand(0)->getType() == I.getType()) {
    OpVS = VS;
  } else {
    OpVS = getVectorSplit(I.getOperand(0)->getType());
    if (!OpVS || VS->NumPacked != OpVS->NumPacked)
      return false;
  }

  IRBuilder<> Builder(&I);
  Scatterer VOp0 = scatter(&I, I.getOperand(0), *OpVS);
  Scatterer VOp1 = scatter(&I, I.getOperand(1), *OpVS);
  assert(VOp0.size() == VS->NumFragments && "Mismatched binary operation");
  assert(VOp1.size() == VS->NumFragments && "Mismatched binary operation");
  ValueVector Res;
  Res.resize(VS->NumFragments);
  for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) {
    Value *Op0 = VOp0[Frag];
    Value *Op1 = VOp1[Frag];
    Res[Frag] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Frag));
  }
  gather(&I, Res, *VS);
  return true;
}

/// If CI is a call to a vector-typed intrinsic function, split it into one
/// scalar call per element, if the intrinsic allows it.
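///
/// For example (illustrative),
///   %s = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %v)
/// is rewritten as two calls to @llvm.sqrt.f32, named %s.i0 and %s.i1.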
bool ScalarizerVisitor::splitCall(CallInst &CI) {
  Type *CallType = CI.getType();
  bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors(CallType);
  std::optional<VectorSplit> VS;
  if (AreAllVectorsOfMatchingSize)
    VS = getVectorSplit(CallType->getContainedType(0));
  else
    VS = getVectorSplit(CallType);
  if (!VS)
    return false;

  Function *F = CI.getCalledFunction();
  if (!F)
    return false;

  Intrinsic::ID ID = F->getIntrinsicID();

  if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
    return false;

  unsigned NumArgs = CI.arg_size();

  ValueVector ScalarOperands(NumArgs);
  SmallVector<Scatterer, 8> Scattered(NumArgs);
  SmallVector<int> OverloadIdx(NumArgs, -1);

  SmallVector<llvm::Type *, 3> Tys;
  // Add return type if intrinsic is overloaded on it.
  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
    Tys.push_back(VS->SplitTy);

  if (AreAllVectorsOfMatchingSize) {
    for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
      std::optional<VectorSplit> CurrVS =
          getVectorSplit(cast<FixedVectorType>(CallType->getContainedType(I)));
      // This case does not seem to happen, but it is possible for
      // VectorSplit.NumPacked >= NumElems.  If that happens, no VectorSplit
      // is returned and we bail out of handling this call.
      // The secondary bailout case is if NumPacked does not match.  This can
      // happen if ScalarizeMinBits is not set to the default.  It means that,
      // for certain values of ScalarizeMinBits, intrinsics like frexp will
      // only be scalarized when the struct elements have the same bit width.
      if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
        return false;
      if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
        Tys.push_back(CurrVS->SplitTy);
    }
  }
  // Assumes that any vector-typed operand has the same number of elements as
  // the return vector type, which is true for all current intrinsics.
  for (unsigned I = 0; I != NumArgs; ++I) {
    Value *OpI = CI.getOperand(I);
    if ([[maybe_unused]] auto *OpVecTy =
            dyn_cast<FixedVectorType>(OpI->getType())) {
      assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements());
      std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType());
      if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
        // The natural split of the operand doesn't match the result. This could
        // happen if the vector elements are different and the ScalarizeMinBits
        // option is used.
        //
        // We could in principle handle this case as well, at the cost of
        // complicating the scattering machinery to support multiple scattering
        // granularities for a single value.
        return false;
      }

      Scattered[I] = scatter(&CI, OpI, *OpVS);
      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) {
        OverloadIdx[I] = Tys.size();
        Tys.push_back(OpVS->SplitTy);
      }
    } else {
      ScalarOperands[I] = OpI;
      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
        Tys.push_back(OpI->getType());
    }
  }

  ValueVector Res(VS->NumFragments);
  ValueVector ScalarCallOps(NumArgs);

  Function *NewIntrin =
      Intrinsic::getOrInsertDeclaration(F->getParent(), ID, Tys);
  IRBuilder<> Builder(&CI);

  // Perform actual scalarization, taking care to preserve any scalar operands.
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    bool IsRemainder = I == VS->NumFragments - 1 && VS->RemainderTy;
    ScalarCallOps.clear();

    if (IsRemainder)
      Tys[0] = VS->RemainderTy;

    for (unsigned J = 0; J != NumArgs; ++J) {
      if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
        ScalarCallOps.push_back(ScalarOperands[J]);
      } else {
        ScalarCallOps.push_back(Scattered[J][I]);
        if (IsRemainder && OverloadIdx[J] >= 0)
          Tys[OverloadIdx[J]] = Scattered[J][I]->getType();
      }
    }

    if (IsRemainder)
      NewIntrin = Intrinsic::getOrInsertDeclaration(F->getParent(), ID, Tys);

    Res[I] = Builder.CreateCall(NewIntrin, ScalarCallOps,
                                CI.getName() + ".i" + Twine(I));
  }

  gather(&CI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
  std::optional<VectorSplit> VS = getVectorSplit(SI.getType());
  if (!VS)
    return false;

  std::optional<VectorSplit> CondVS;
  if (isa<FixedVectorType>(SI.getCondition()->getType())) {
    CondVS = getVectorSplit(SI.getCondition()->getType());
    if (!CondVS || CondVS->NumPacked != VS->NumPacked) {
      // This happens when ScalarizeMinBits is used.
      return false;
    }
  }

  IRBuilder<> Builder(&SI);
  Scatterer VOp1 = scatter(&SI, SI.getOperand(1), *VS);
  Scatterer VOp2 = scatter(&SI, SI.getOperand(2), *VS);
  assert(VOp1.size() == VS->NumFragments && "Mismatched select");
  assert(VOp2.size() == VS->NumFragments && "Mismatched select");
  ValueVector Res;
  Res.resize(VS->NumFragments);

  if (CondVS) {
    Scatterer VOp0 = scatter(&SI, SI.getOperand(0), *CondVS);
    assert(VOp0.size() == CondVS->NumFragments && "Mismatched select");
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      Value *Op0 = VOp0[I];
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  } else {
    Value *Op0 = SI.getOperand(0);
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  }
  gather(&SI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
  return splitBinary(ICI, ICmpSplitter(ICI));
}

bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
  return splitBinary(FCI, FCmpSplitter(FCI));
}

bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
  return splitUnary(UO, UnarySplitter(UO));
}

bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
  return splitBinary(BO, BinarySplitter(BO));
}

bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  std::optional<VectorSplit> VS = getVectorSplit(GEPI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&GEPI);
  unsigned NumIndices = GEPI.getNumIndices();

  // The base pointer and indices might be scalar even if it's a vector GEP.
  SmallVector<Value *, 8> ScalarOps{1 + NumIndices};
  SmallVector<Scatterer, 8> ScatterOps{1 + NumIndices};

  for (unsigned I = 0; I < 1 + NumIndices; ++I) {
    if (auto *VecTy =
            dyn_cast<FixedVectorType>(GEPI.getOperand(I)->getType())) {
      std::optional<VectorSplit> OpVS = getVectorSplit(VecTy);
      if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
        // This can happen when ScalarizeMinBits is used.
        return false;
      }
      ScatterOps[I] = scatter(&GEPI, GEPI.getOperand(I), *OpVS);
    } else {
      ScalarOps[I] = GEPI.getOperand(I);
    }
  }

  ValueVector Res;
  Res.resize(VS->NumFragments);
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    SmallVector<Value *, 8> SplitOps;
    SplitOps.resize(1 + NumIndices);
    for (unsigned J = 0; J < 1 + NumIndices; ++J) {
      if (ScalarOps[J])
        SplitOps[J] = ScalarOps[J];
      else
        SplitOps[J] = ScatterOps[J][I];
    }
    Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), SplitOps[0],
                               ArrayRef(SplitOps).drop_front(),
                               GEPI.getName() + ".i" + Twine(I));
    if (GEPI.isInBounds())
      if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
        NewGEPI->setIsInBounds();
  }
  gather(&GEPI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
  std::optional<VectorSplit> DestVS = getVectorSplit(CI.getDestTy());
  if (!DestVS)
    return false;

  std::optional<VectorSplit> SrcVS = getVectorSplit(CI.getSrcTy());
  if (!SrcVS || SrcVS->NumPacked != DestVS->NumPacked)
    return false;

  IRBuilder<> Builder(&CI);
  Scatterer Op0 = scatter(&CI, CI.getOperand(0), *SrcVS);
  assert(Op0.size() == SrcVS->NumFragments && "Mismatched cast");
  ValueVector Res;
  Res.resize(DestVS->NumFragments);
  for (unsigned I = 0; I < DestVS->NumFragments; ++I)
    Res[I] =
        Builder.CreateCast(CI.getOpcode(), Op0[I], DestVS->getFragmentType(I),
                           CI.getName() + ".i" + Twine(I));
  gather(&CI, Res, *DestVS);
  return true;
}

bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
  std::optional<VectorSplit> DstVS = getVectorSplit(BCI.getDestTy());
  std::optional<VectorSplit> SrcVS = getVectorSplit(BCI.getSrcTy());
  if (!DstVS || !SrcVS || DstVS->RemainderTy || SrcVS->RemainderTy)
    return false;

  const bool isPointerTy = DstVS->VecTy->getElementType()->isPointerTy();

  // Vectors of pointers are always fully scalarized.
  assert(!isPointerTy || (DstVS->NumPacked == 1 && SrcVS->NumPacked == 1));

  IRBuilder<> Builder(&BCI);
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0), *SrcVS);
  ValueVector Res;
  Res.resize(DstVS->NumFragments);

  unsigned DstSplitBits = DstVS->SplitTy->getPrimitiveSizeInBits();
  unsigned SrcSplitBits = SrcVS->SplitTy->getPrimitiveSizeInBits();

  if (isPointerTy || DstSplitBits == SrcSplitBits) {
    assert(DstVS->NumFragments == SrcVS->NumFragments);
    for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
      Res[I] = Builder.CreateBitCast(Op0[I], DstVS->getFragmentType(I),
                                     BCI.getName() + ".i" + Twine(I));
    }
  } else if (SrcSplitBits % DstSplitBits == 0) {
    // Convert each source fragment to the same-sized destination vector and
    // then scatter the result to the destination.
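    // For example (illustrative), bitcasting <2 x i64> to <4 x i32> with both
    // sides fully scalarized: each i64 fragment is bitcast to a <2 x i32>
    // intermediate (MidVS) and then scattered into two i32 results.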
    VectorSplit MidVS;
    MidVS.NumPacked = DstVS->NumPacked;
    MidVS.NumFragments = SrcSplitBits / DstSplitBits;
    MidVS.VecTy = FixedVectorType::get(DstVS->VecTy->getElementType(),
                                       MidVS.NumPacked * MidVS.NumFragments);
    MidVS.SplitTy = DstVS->SplitTy;

    unsigned ResI = 0;
    for (unsigned I = 0; I < SrcVS->NumFragments; ++I) {
      Value *V = Op0[I];

      // Look through any existing bitcasts before converting to <N x t2>.
      // In the best case, the resulting conversion might be a no-op.
      Instruction *VI;
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);

      V = Builder.CreateBitCast(V, MidVS.VecTy, V->getName() + ".cast");

      Scatterer Mid = scatter(&BCI, V, MidVS);
      for (unsigned J = 0; J < MidVS.NumFragments; ++J)
        Res[ResI++] = Mid[J];
    }
  } else if (DstSplitBits % SrcSplitBits == 0) {
    // Gather enough source fragments to make up a destination fragment and
    // then convert to the destination type.
    VectorSplit MidVS;
    MidVS.NumFragments = DstSplitBits / SrcSplitBits;
    MidVS.NumPacked = SrcVS->NumPacked;
    MidVS.VecTy = FixedVectorType::get(SrcVS->VecTy->getElementType(),
                                       MidVS.NumPacked * MidVS.NumFragments);
    MidVS.SplitTy = SrcVS->SplitTy;

    unsigned SrcI = 0;
    SmallVector<Value *, 8> ConcatOps;
    ConcatOps.resize(MidVS.NumFragments);
    for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
      for (unsigned J = 0; J < MidVS.NumFragments; ++J)
        ConcatOps[J] = Op0[SrcI++];
      Value *V = concatenate(Builder, ConcatOps, MidVS,
                             BCI.getName() + ".i" + Twine(I));
      Res[I] = Builder.CreateBitCast(V, DstVS->getFragmentType(I),
                                     BCI.getName() + ".i" + Twine(I));
    }
  } else {
    return false;
  }

  gather(&BCI, Res, *DstVS);
  return true;
}

bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  std::optional<VectorSplit> VS = getVectorSplit(IEI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&IEI);
  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0), *VS);
  Value *NewElt = IEI.getOperand(1);
  Value *InsIdx = IEI.getOperand(2);

  ValueVector Res;
  Res.resize(VS->NumFragments);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    unsigned Idx = CI->getZExtValue();
    unsigned Fragment = Idx / VS->NumPacked;
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      if (I == Fragment) {
        bool IsPacked = VS->NumPacked > 1;
        if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
            !VS->RemainderTy->isVectorTy())
          IsPacked = false;
        if (IsPacked) {
          Res[I] =
              Builder.CreateInsertElement(Op0[I], NewElt, Idx % VS->NumPacked);
        } else {
          Res[I] = NewElt;
        }
      } else {
        Res[I] = Op0[I];
      }
    }
  } else {
    // Never split a variable insertelement that isn't fully scalarized.
    if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
      return false;

    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      Value *ShouldReplace =
          Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
                               InsIdx->getName() + ".is." + Twine(I));
      Value *OldElt = Op0[I];
      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
                                    IEI.getName() + ".i" + Twine(I));
    }
  }

  gather(&IEI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
  Value *Op = EVI.getOperand(0);
  Type *OpTy = Op->getType();
  ValueVector Res;
  if (!isStructOfMatchingFixedVectors(OpTy))
    return false;
  if (CallInst *CI = dyn_cast<CallInst>(Op)) {
    Function *F = CI->getCalledFunction();
    if (!F)
      return false;
    Intrinsic::ID ID = F->getIntrinsicID();
    if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
      return false;
    // Note: falling through here means Op is a `CallInst` to an intrinsic
    // that `isTriviallyScalarizable` has accepted.
  } else
    return false;
  Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
  std::optional<VectorSplit> VS = getVectorSplit(VecType);
  if (!VS)
    return false;
  IRBuilder<> Builder(&EVI);
  Scatterer Op0 = scatter(&EVI, Op, *VS);
  assert(!EVI.getIndices().empty() && "Make sure an index exists");
  // Note: for our use case we only care about the top-level index.
  unsigned Index = EVI.getIndices()[0];
  for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
    Value *ResElem = Builder.CreateExtractValue(
        Op0[OpIdx], Index, EVI.getName() + ".elem" + Twine(Index));
    Res.push_back(ResElem);
  }

  gather(&EVI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&EEI);
  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0), *VS);
  Value *ExtIdx = EEI.getOperand(1);

  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
    unsigned Idx = CI->getZExtValue();
    unsigned Fragment = Idx / VS->NumPacked;
    Value *Res = Op0[Fragment];
    bool IsPacked = VS->NumPacked > 1;
    if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
        !VS->RemainderTy->isVectorTy())
      IsPacked = false;
    if (IsPacked)
      Res = Builder.CreateExtractElement(Res, Idx % VS->NumPacked);
    replaceUses(&EEI, Res);
    return true;
  }

  // Never split a variable extractelement that isn't fully scalarized.
  if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
    return false;

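  // Lower the variable extract to a chain of compare+select over all
  // fragments; e.g. (illustrative, for a <2 x i32> operand %v and index
  // %idx):
  //   %idx.is.0 = icmp eq i32 %idx, 0
  //   %r.upto0 = select i1 %idx.is.0, i32 %v.i0, i32 poison
  //   %idx.is.1 = icmp eq i32 %idx, 1
  //   %r.upto1 = select i1 %idx.is.1, i32 %v.i1, i32 %r.upto0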
  Value *Res = PoisonValue::get(VS->VecTy->getElementType());
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    Value *ShouldExtract =
        Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
                             ExtIdx->getName() + ".is." + Twine(I));
    Value *Elt = Op0[I];
    Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
                               EEI.getName() + ".upto" + Twine(I));
  }
  replaceUses(&EEI, Res);
  return true;
}

bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
  std::optional<VectorSplit> VS = getVectorSplit(SVI.getType());
  std::optional<VectorSplit> VSOp =
      getVectorSplit(SVI.getOperand(0)->getType());
  if (!VS || !VSOp || VS->NumPacked > 1 || VSOp->NumPacked > 1)
    return false;

  Scatterer Op0 = scatter(&SVI, SVI.getOperand(0), *VSOp);
  Scatterer Op1 = scatter(&SVI, SVI.getOperand(1), *VSOp);
  ValueVector Res;
  Res.resize(VS->NumFragments);

  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    int Selector = SVI.getMaskValue(I);
    if (Selector < 0)
      Res[I] = PoisonValue::get(VS->VecTy->getElementType());
    else if (unsigned(Selector) < Op0.size())
      Res[I] = Op0[Selector];
    else
      Res[I] = Op1[Selector - Op0.size()];
  }
  gather(&SVI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
  std::optional<VectorSplit> VS = getVectorSplit(PHI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&PHI);
  ValueVector Res;
  Res.resize(VS->NumFragments);

  unsigned NumOps = PHI.getNumOperands();
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    Res[I] = Builder.CreatePHI(VS->getFragmentType(I), NumOps,
                               PHI.getName() + ".i" + Twine(I));
  }

  for (unsigned I = 0; I < NumOps; ++I) {
    Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I), *VS);
    BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
    for (unsigned J = 0; J < VS->NumFragments; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
  }
  gather(&PHI, Res, *VS);
  return true;
}

bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!LI.isSimple())
    return false;

  std::optional<VectorLayout> Layout = getVectorLayout(
      LI.getType(), LI.getAlign(), LI.getDataLayout());
  if (!Layout)
    return false;

  IRBuilder<> Builder(&LI);
  Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), Layout->VS);
  ValueVector Res;
  Res.resize(Layout->VS.NumFragments);

  for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
    Res[I] = Builder.CreateAlignedLoad(Layout->VS.getFragmentType(I), Ptr[I],
                                       Align(Layout->getFragmentAlign(I)),
                                       LI.getName() + ".i" + Twine(I));
  }
  gather(&LI, Res, Layout->VS);
  return true;
}

bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!SI.isSimple())
    return false;

  Value *FullValue = SI.getValueOperand();
  std::optional<VectorLayout> Layout = getVectorLayout(
      FullValue->getType(), SI.getAlign(), SI.getDataLayout());
  if (!Layout)
    return false;

  IRBuilder<> Builder(&SI);
  Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), Layout->VS);
  Scatterer VVal = scatter(&SI, FullValue, Layout->VS);

  ValueVector Stores;
  Stores.resize(Layout->VS.NumFragments);
  for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
    Value *Val = VVal[I];
    Value *Ptr = VPtr[I];
    Stores[I] =
        Builder.CreateAlignedStore(Val, Ptr, Layout->getFragmentAlign(I));
  }
  transferMetadataAndIRFlags(&SI, Stores);
  return true;
}

bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
}

bool ScalarizerVisitor::visitFreezeInst(FreezeInst &FI) {
  return splitUnary(FI, [](IRBuilder<> &Builder, Value *Op, const Twine &Name) {
    return Builder.CreateFreeze(Op, Name);
  });
}

// Delete the instructions that we scalarized.  If a full vector result
// is still needed, recreate it using InsertElements.
bool ScalarizerVisitor::finish() {
  // The presence of data in Gathered or Scattered indicates changes
  // made to the Function.
  if (Gathered.empty() && Scattered.empty() && !Scalarized)
    return false;
  for (const auto &GMI : Gathered) {
    Instruction *Op = GMI.first;
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      // The value is still needed, so recreate it using a series of
      // insertelements and/or shufflevectors.
      Value *Res;
      if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        IRBuilder<> Builder(Op);
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

        VectorSplit VS = *getVectorSplit(Ty);
        assert(VS.NumFragments == CV.size());

        Res = concatenate(Builder, CV, VS, Op->getName());

        Res->takeName(Op);
      } else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        IRBuilder<> Builder(Op);
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

        // Iterate over each element in the struct
        unsigned NumOfStructElements = Ty->getNumElements();
        SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
        for (unsigned I = 0; I < NumOfStructElements; ++I) {
          for (auto *CVelem : CV) {
            Value *Elem = Builder.CreateExtractValue(
                CVelem, I, Op->getName() + ".elem" + Twine(I));
            ElemCV[I].push_back(Elem);
          }
        }
        Res = PoisonValue::get(Ty);
        for (unsigned I = 0; I < NumOfStructElements; ++I) {
          Type *ElemTy = Ty->getElementType(I);
          assert(isa<FixedVectorType>(ElemTy) &&
                 "Only Structs of all FixedVectorType supported");
          VectorSplit VS = *getVectorSplit(ElemTy);
          assert(VS.NumFragments == CV.size());

          Value *ConcatenatedVector =
              concatenate(Builder, ElemCV[I], VS, Op->getName());
          Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
                                          Op->getName() + ".insert");
        }
      } else {
        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
        Res = CV[0];
        if (Op == Res)
          continue;
      }
      Op->replaceAllUsesWith(Res);
    }
    PotentiallyDeadInstrs.emplace_back(Op);
  }
  Gathered.clear();
  Scattered.clear();
  Scalarized = false;

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);

  return true;
}

PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
  ScalarizerVisitor Impl(DT, TTI, Options);
  bool Changed = Impl.visit(F);
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return Changed ? PA : PreservedAnalyses::all();
}