//===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass converts vector operations into scalar operations, in order
// to expose optimization opportunities on the individual scalar operations.
// It is mainly intended for targets that do not have vector units, but it
// may also be useful for revectorizing code to different vector widths.
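//
// As an illustrative sketch (the IR below is hypothetical, not taken from a
// test case), a vector add such as
//
//   %res = add <4 x i32> %a, %b
//
// is rewritten as four scalar adds on the individual elements; the vector
// result is rebuilt with insertelements only where a full vector value is
// still required (see ScalarizerVisitor::finish below).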
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/Scalarizer.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
#include <iterator>
#include <map>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "scalarizer"

static cl::opt<bool> ClScalarizeVariableInsertExtract(
    "scalarize-variable-insert-extract", cl::init(true), cl::Hidden,
    cl::desc("Allow the scalarizer pass to scalarize "
             "insertelement/extractelement with variable index"));

// This is disabled by default because having separate loads and stores
// makes it more likely that the -combiner-alias-analysis limits will be
// reached.
static cl::opt<bool> ClScalarizeLoadStore(
    "scalarize-load-store", cl::init(false), cl::Hidden,
    cl::desc("Allow the scalarizer pass to scalarize loads and stores"));

namespace {

BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
  BasicBlock *BB = Itr->getParent();
  if (isa<PHINode>(Itr))
    Itr = BB->getFirstInsertionPt();
  if (Itr != BB->end())
    Itr = skipDebugIntrinsics(Itr);
  return Itr;
}

// Used to store the scattered form of a vector.
using ValueVector = SmallVector<Value *, 8>;

// Used to map a vector Value and associated type to its scattered form.
// The associated type is only non-null for pointer values that are "scattered"
// when used as pointer operands to load or store.
//
// We use std::map because we want iterators to persist across insertion and
// because the values are relatively large.
using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>;

// Lists Instructions that have been replaced with scalar implementations,
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;

// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
public:
  Scatterer() = default;

  // Scatter V into Size components. If new instructions are needed,
  // insert them before BBI in BB. If Cache is nonnull, use it to cache
  // the results.
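  //
  // As a hedged illustration: for a vector value %v of type <4 x float>,
  // component 2 is materialized on demand as
  //   %v.i2 = extractelement <4 x float> %v, i32 2
  // unless operator[] below can recover it by walking a chain of
  // insertelements.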
  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, Type *PtrElemTy,
            ValueVector *cachePtr = nullptr);

  // Return component I, creating a new Value for it if necessary.
  Value *operator[](unsigned I);

  // Return the number of components.
  unsigned size() const { return Size; }

private:
  BasicBlock *BB;
  BasicBlock::iterator BBI;
  Value *V;
  Type *PtrElemTy;
  ValueVector *CachePtr;
  ValueVector Tmp;
  unsigned Size;
};

// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
// called Name that compares X and Y in the same way as FCI.
struct FCmpSplitter {
  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
  }

  FCmpInst &FCI;
};

// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
// called Name that compares X and Y in the same way as ICI.
struct ICmpSplitter {
  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
  }

  ICmpInst &ICI;
};

// UnarySplitter(UO)(Builder, X, Name) uses Builder to create
// a unary operator like UO called Name with operand X.
struct UnarySplitter {
  UnarySplitter(UnaryOperator &uo) : UO(uo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
    return Builder.CreateUnOp(UO.getOpcode(), Op, Name);
  }

  UnaryOperator &UO;
};

// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
// a binary operator like BO called Name with operands X and Y.
struct BinarySplitter {
  BinarySplitter(BinaryOperator &bo) : BO(bo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
  }

  BinaryOperator &BO;
};

// Information about a load or store that we're scalarizing.
struct VectorLayout {
  VectorLayout() = default;

  // Return the alignment of element I.
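  // For example (assuming a <4 x i32> vector with VecAlign 16 and ElemSize 4),
  // the element alignments come out as 16, 4, 8 and 4 respectively.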
  Align getElemAlign(unsigned I) {
    return commonAlignment(VecAlign, I * ElemSize);
  }

  // The type of the vector.
  FixedVectorType *VecTy = nullptr;

  // The type of each element.
  Type *ElemTy = nullptr;

  // The alignment of the vector.
  Align VecAlign;

  // The size of each element.
  uint64_t ElemSize = 0;
};

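// Returns ClOption's value if it was set explicitly on the command line,
// otherwise the DefaultOverride supplied by the pass options (if any),
// otherwise ClOption's built-in default.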
template <typename T>
T getWithDefaultOverride(const cl::opt<T> &ClOption,
                         const std::optional<T> &DefaultOverride) {
  return ClOption.getNumOccurrences() ? ClOption
                                      : DefaultOverride.value_or(ClOption);
}

class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
public:
  ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT,
                    ScalarizerPassOptions Options)
      : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT),
        ScalarizeVariableInsertExtract(
            getWithDefaultOverride(ClScalarizeVariableInsertExtract,
                                   Options.ScalarizeVariableInsertExtract)),
        ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore,
                                                  Options.ScalarizeLoadStore)) {
  }

  bool visit(Function &F);

  // InstVisitor methods. They return true if the instruction was scalarized,
  // false if nothing changed.
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI);
  bool visitICmpInst(ICmpInst &ICI);
  bool visitFCmpInst(FCmpInst &FCI);
  bool visitUnaryOperator(UnaryOperator &UO);
  bool visitBinaryOperator(BinaryOperator &BO);
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitCastInst(CastInst &CI);
  bool visitBitCastInst(BitCastInst &BCI);
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
  bool visitPHINode(PHINode &PHI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &ICI);

private:
  Scatterer scatter(Instruction *Point, Value *V, Type *PtrElemTy = nullptr);
  void gather(Instruction *Op, const ValueVector &CV);
  void replaceUses(Instruction *Op, Value *CV);
  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
  std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
                                              const DataLayout &DL);
  bool finish();

  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);

  bool splitCall(CallInst &CI);

  ScatterMap Scattered;
  GatherList Gathered;
  bool Scalarized;

  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;

  unsigned ParallelLoopAccessMDKind;

  DominatorTree *DT;

  const bool ScalarizeVariableInsertExtract;
  const bool ScalarizeLoadStore;
};

class ScalarizerLegacyPass : public FunctionPass {
public:
  static char ID;

  ScalarizerLegacyPass() : FunctionPass(ID) {
    initializeScalarizerLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

char ScalarizerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
                      "Scalarize vector operations", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
                    "Scalarize vector operations", false, false)

Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                     Type *PtrElemTy, ValueVector *cachePtr)
    : BB(bb), BBI(bbi), V(v), PtrElemTy(PtrElemTy), CachePtr(cachePtr) {
  Type *Ty = V->getType();
  if (Ty->isPointerTy()) {
    assert(cast<PointerType>(Ty)->isOpaqueOrPointeeTypeMatches(PtrElemTy) &&
           "Pointer element type mismatch");
    Ty = PtrElemTy;
  }
  Size = cast<FixedVectorType>(Ty)->getNumElements();
  if (!CachePtr)
    Tmp.resize(Size, nullptr);
  else if (CachePtr->empty())
    CachePtr->resize(Size, nullptr);
  else
    assert(Size == CachePtr->size() && "Inconsistent vector sizes");
}

// Return component I, creating a new Value for it if necessary.
Value *Scatterer::operator[](unsigned I) {
  ValueVector &CV = (CachePtr ? *CachePtr : Tmp);
  // Try to reuse a previous value.
  if (CV[I])
    return CV[I];
  IRBuilder<> Builder(BB, BBI);
  if (PtrElemTy) {
    Type *VectorElemTy = cast<VectorType>(PtrElemTy)->getElementType();
    if (!CV[0]) {
      Type *NewPtrTy = PointerType::get(
          VectorElemTy, V->getType()->getPointerAddressSpace());
      CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0");
    }
    if (I != 0)
      CV[I] = Builder.CreateConstGEP1_32(VectorElemTy, CV[0], I,
                                         V->getName() + ".i" + Twine(I));
  } else {
    // Search through a chain of InsertElementInsts looking for element I.
    // Record other elements in the cache. The new V is still suitable
    // for all uncached indices.
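    // For instance (a hypothetical sketch), given
    //   %v1 = insertelement <2 x float> %v0, float %x, i32 0
    //   %v2 = insertelement <2 x float> %v1, float %y, i32 1
    // a request for element 0 of %v2 returns %x directly, caching %y for
    // index 1 along the way.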
    while (true) {
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
      if (!Insert)
        break;
      ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
      if (!Idx)
        break;
      unsigned J = Idx->getZExtValue();
      V = Insert->getOperand(0);
      if (I == J) {
        CV[J] = Insert->getOperand(1);
        return CV[J];
      } else if (!CV[J]) {
        // Only cache the first entry we find for each index we're not actively
        // searching for. This prevents us from going too far up the chain and
        // caching incorrect entries.
        CV[J] = Insert->getOperand(1);
      }
    }
    CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I),
                                         V->getName() + ".i" + Twine(I));
  }
  return CV[I];
}

bool ScalarizerLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  Module &M = *F.getParent();
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, ScalarizerPassOptions());
  return Impl.visit(F);
}

FunctionPass *llvm::createScalarizerPass() {
  return new ScalarizerLegacyPass();
}

bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());

  Scalarized = false;

  // To ensure we replace gathered components correctly, we need to do an
  // ordered traversal of the basic blocks in the function.
  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
  for (BasicBlock *BB : RPOT) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      Instruction *I = &*II;
      bool Done = InstVisitor::visit(I);
      ++II;
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
    }
  }
  return finish();
}

// Return a scattered form of V that can be accessed by Point. V must be a
// vector or a pointer to a vector.
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
                                     Type *PtrElemTy) {
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    // Put the scattered form of arguments in the entry block,
    // so that it can be used everywhere.
    Function *F = VArg->getParent();
    BasicBlock *BB = &F->getEntryBlock();
    return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[{V, PtrElemTy}]);
  }
  if (Instruction *VOp = dyn_cast<Instruction>(V)) {
    // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
    // nodes in predecessors. If those predecessors are unreachable from entry,
    // then the IR in those blocks could have unexpected properties resulting in
    // infinite loops in Scatterer::operator[]. By simply treating values
    // originating from instructions in unreachable blocks as poison we do not
    // need to analyse them further.
    if (!DT->isReachableFromEntry(VOp->getParent()))
      return Scatterer(Point->getParent(), Point->getIterator(),
                       PoisonValue::get(V->getType()), PtrElemTy);
    // Put the scattered form of an instruction directly after the
    // instruction, skipping over PHI nodes and debug intrinsics.
    BasicBlock *BB = VOp->getParent();
    return Scatterer(
        BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V,
        PtrElemTy, &Scattered[{V, PtrElemTy}]);
  }
  // In the fallback case, just put the scattered form before Point and
  // keep the result local to Point.
  return Scatterer(Point->getParent(), Point->getIterator(), V, PtrElemTy);
}

// Replace Op with the gathered form of the components in CV. Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
  transferMetadataAndIRFlags(Op, CV);

  // If we already have a scattered form of Op (created from ExtractElements
  // of Op itself), replace them with the new form.
  ValueVector &SV = Scattered[{Op, nullptr}];
  if (!SV.empty()) {
    for (unsigned I = 0, E = SV.size(); I != E; ++I) {
      Value *V = SV[I];
      if (V == nullptr || SV[I] == CV[I])
        continue;

      Instruction *Old = cast<Instruction>(V);
      if (isa<Instruction>(CV[I]))
        CV[I]->takeName(Old);
      Old->replaceAllUsesWith(CV[I]);
      PotentiallyDeadInstrs.emplace_back(Old);
    }
  }
  SV = CV;
  Gathered.push_back(GatherList::value_type(Op, &SV));
}

// Replace Op with CV and collect Op as a potentially dead instruction.
void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) {
  if (CV != Op) {
    Op->replaceAllUsesWith(CV);
    PotentiallyDeadInstrs.emplace_back(Op);
    Scalarized = true;
  }
}

// Return true if it is safe to transfer the given metadata tag from
// vector to scalar instructions.
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
  return (Tag == LLVMContext::MD_tbaa
          || Tag == LLVMContext::MD_fpmath
          || Tag == LLVMContext::MD_tbaa_struct
          || Tag == LLVMContext::MD_invariant_load
          || Tag == LLVMContext::MD_alias_scope
          || Tag == LLVMContext::MD_noalias
          || Tag == ParallelLoopAccessMDKind
          || Tag == LLVMContext::MD_access_group);
}

// Transfer metadata, IR flags and any debug location from Op to the
// instructions in CV if it is known to be safe to do so.
void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
                                                   const ValueVector &CV) {
  SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
  Op->getAllMetadataOtherThanDebugLoc(MDs);
  for (unsigned I = 0, E = CV.size(); I != E; ++I) {
    if (Instruction *New = dyn_cast<Instruction>(CV[I])) {
      for (const auto &MD : MDs)
        if (canTransferMetadata(MD.first))
          New->setMetadata(MD.first, MD.second);
      New->copyIRFlags(Op);
      if (Op->getDebugLoc() && !New->getDebugLoc())
        New->setDebugLoc(Op->getDebugLoc());
    }
  }
}

// Try to fill in Layout from Ty, returning the layout on success and
// std::nullopt on failure. Alignment is the alignment of the vector.
std::optional<VectorLayout>
ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
                                   const DataLayout &DL) {
  VectorLayout Layout;
  // Make sure we're dealing with a vector.
  Layout.VecTy = dyn_cast<FixedVectorType>(Ty);
  if (!Layout.VecTy)
    return std::nullopt;
  // Check that we're dealing with full-byte elements.
  Layout.ElemTy = Layout.VecTy->getElementType();
  if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy))
    return std::nullopt;
  Layout.VecAlign = Alignment;
  Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
  return Layout;
}

// Scalarize one-operand instruction I, using Split(Builder, X, Name)
// to create an instruction like I with operand X and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
  auto *VT = dyn_cast<FixedVectorType>(I.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&I);
  Scatterer Op = scatter(&I, I.getOperand(0));
  assert(Op.size() == NumElems && "Mismatched unary operation");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem)
    Res[Elem] = Split(Builder, Op[Elem], I.getName() + ".i" + Twine(Elem));
  gather(&I, Res);
  return true;
}

// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
// to create an instruction like I with operands X and Y and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
  auto *VT = dyn_cast<FixedVectorType>(I.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&I);
  Scatterer VOp0 = scatter(&I, I.getOperand(0));
  Scatterer VOp1 = scatter(&I, I.getOperand(1));
  assert(VOp0.size() == NumElems && "Mismatched binary operation");
  assert(VOp1.size() == NumElems && "Mismatched binary operation");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    Value *Op0 = VOp0[Elem];
    Value *Op1 = VOp1[Elem];
    Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem));
  }
  gather(&I, Res);
  return true;
}

static bool isTriviallyScalarizable(Intrinsic::ID ID) {
  return isTriviallyVectorizable(ID);
}

// Return the scalar form of the intrinsic, mangled on the given overload
// types. All of the current scalarizable intrinsics have at most a handful
// of mangled types, and most have only one.
static Function *getScalarIntrinsicDeclaration(Module *M,
                                               Intrinsic::ID ID,
                                               ArrayRef<Type*> Tys) {
  return Intrinsic::getDeclaration(M, ID, Tys);
}

/// If CI is a call to a vector-typed intrinsic function, split it into one
/// scalar call per element, if possible for that intrinsic.
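///
/// For example (a sketch, not taken from a test case):
///   %r = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x)
/// becomes two calls to @llvm.sqrt.f32, one per element, gathered back into
/// a <2 x float> only if a vector result is still needed.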
bool ScalarizerVisitor::splitCall(CallInst &CI) {
  auto *VT = dyn_cast<FixedVectorType>(CI.getType());
  if (!VT)
    return false;

  Function *F = CI.getCalledFunction();
  if (!F)
    return false;

  Intrinsic::ID ID = F->getIntrinsicID();
  if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
    return false;

  unsigned NumElems = VT->getNumElements();
  unsigned NumArgs = CI.arg_size();

  ValueVector ScalarOperands(NumArgs);
  SmallVector<Scatterer, 8> Scattered(NumArgs);

  SmallVector<llvm::Type *, 3> Tys;
  Tys.push_back(VT->getScalarType());

  // Assumes that any vector type has the same number of elements as the return
  // vector type, which is true for all current intrinsics.
  for (unsigned I = 0; I != NumArgs; ++I) {
    Value *OpI = CI.getOperand(I);
    if (OpI->getType()->isVectorTy()) {
      Scattered[I] = scatter(&CI, OpI);
      assert(Scattered[I].size() == NumElems && "mismatched call operands");
      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
        Tys.push_back(OpI->getType()->getScalarType());
    } else {
      ScalarOperands[I] = OpI;
      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
        Tys.push_back(OpI->getType());
    }
  }

  ValueVector Res(NumElems);
  ValueVector ScalarCallOps(NumArgs);

  Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, Tys);
  IRBuilder<> Builder(&CI);

  // Perform actual scalarization, taking care to preserve any scalar operands.
  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    ScalarCallOps.clear();

    for (unsigned J = 0; J != NumArgs; ++J) {
      if (isVectorIntrinsicWithScalarOpAtArg(ID, J))
        ScalarCallOps.push_back(ScalarOperands[J]);
      else
        ScalarCallOps.push_back(Scattered[J][Elem]);
    }

    Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps,
                                   CI.getName() + ".i" + Twine(Elem));
  }

  gather(&CI, Res);
  return true;
}

bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
  auto *VT = dyn_cast<FixedVectorType>(SI.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&SI);
  Scatterer VOp1 = scatter(&SI, SI.getOperand(1));
  Scatterer VOp2 = scatter(&SI, SI.getOperand(2));
  assert(VOp1.size() == NumElems && "Mismatched select");
  assert(VOp2.size() == NumElems && "Mismatched select");
  ValueVector Res;
  Res.resize(NumElems);

  if (SI.getOperand(0)->getType()->isVectorTy()) {
    Scatterer VOp0 = scatter(&SI, SI.getOperand(0));
    assert(VOp0.size() == NumElems && "Mismatched select");
    for (unsigned I = 0; I < NumElems; ++I) {
      Value *Op0 = VOp0[I];
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  } else {
    Value *Op0 = SI.getOperand(0);
    for (unsigned I = 0; I < NumElems; ++I) {
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  }
  gather(&SI, Res);
  return true;
}

bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
  return splitBinary(ICI, ICmpSplitter(ICI));
}

bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
  return splitBinary(FCI, FCmpSplitter(FCI));
}

bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
  return splitUnary(UO, UnarySplitter(UO));
}

bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
  return splitBinary(BO, BinarySplitter(BO));
}

bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  auto *VT = dyn_cast<FixedVectorType>(GEPI.getType());
  if (!VT)
    return false;

  IRBuilder<> Builder(&GEPI);
  unsigned NumElems = VT->getNumElements();
  unsigned NumIndices = GEPI.getNumIndices();

  // The base pointer might be scalar even if it's a vector GEP. In those cases,
  // splat the pointer into a vector value, and scatter that vector.
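  // E.g. (a hypothetical sketch) for
  //   getelementptr i32, ptr %p, <4 x i64> %idx
  // the scalar base %p is splatted to <4 x ptr> so that each lane has its
  // own base pointer to scatter.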
  Value *Op0 = GEPI.getOperand(0);
  if (!Op0->getType()->isVectorTy())
    Op0 = Builder.CreateVectorSplat(NumElems, Op0);
  Scatterer Base = scatter(&GEPI, Op0);

  SmallVector<Scatterer, 8> Ops;
  Ops.resize(NumIndices);
  for (unsigned I = 0; I < NumIndices; ++I) {
    Value *Op = GEPI.getOperand(I + 1);

    // The indices might be scalars even if it's a vector GEP. In those cases,
    // splat the scalar into a vector value, and scatter that vector.
    if (!Op->getType()->isVectorTy())
      Op = Builder.CreateVectorSplat(NumElems, Op);

    Ops[I] = scatter(&GEPI, Op);
  }

  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    SmallVector<Value *, 8> Indices;
    Indices.resize(NumIndices);
    for (unsigned J = 0; J < NumIndices; ++J)
      Indices[J] = Ops[J][I];
    Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), Base[I], Indices,
                               GEPI.getName() + ".i" + Twine(I));
    if (GEPI.isInBounds())
      if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
        NewGEPI->setIsInBounds();
  }
  gather(&GEPI, Res);
  return true;
}

bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
  auto *VT = dyn_cast<FixedVectorType>(CI.getDestTy());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&CI);
  Scatterer Op0 = scatter(&CI, CI.getOperand(0));
  assert(Op0.size() == NumElems && "Mismatched cast");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(),
                                CI.getName() + ".i" + Twine(I));
  gather(&CI, Res);
  return true;
}

bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
  auto *DstVT = dyn_cast<FixedVectorType>(BCI.getDestTy());
  auto *SrcVT = dyn_cast<FixedVectorType>(BCI.getSrcTy());
  if (!DstVT || !SrcVT)
    return false;

  unsigned DstNumElems = DstVT->getNumElements();
  unsigned SrcNumElems = SrcVT->getNumElements();
  IRBuilder<> Builder(&BCI);
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
  ValueVector Res;
  Res.resize(DstNumElems);

  if (DstNumElems == SrcNumElems) {
    for (unsigned I = 0; I < DstNumElems; ++I)
      Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(),
                                     BCI.getName() + ".i" + Twine(I));
  } else if (DstNumElems > SrcNumElems) {
    // <M x t1> -> <N*M x t2>. Convert each t1 to <N x t2> and copy the
    // individual elements to the destination.
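    // E.g. (a sketch) <2 x i64> -> <4 x i32>: FanOut is 2, so each i64
    // source element is bitcast to a <2 x i32> "MidTy" value whose two
    // elements are copied to consecutive result slots.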
    unsigned FanOut = DstNumElems / SrcNumElems;
    auto *MidTy = FixedVectorType::get(DstVT->getElementType(), FanOut);
    unsigned ResI = 0;
    for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
      Value *V = Op0[Op0I];
      Instruction *VI;
      // Look through any existing bitcasts before converting to <N x t2>.
      // In the best case, the resulting conversion might be a no-op.
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);
      V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast");
      Scatterer Mid = scatter(&BCI, V);
      for (unsigned MidI = 0; MidI < FanOut; ++MidI)
        Res[ResI++] = Mid[MidI];
    }
  } else {
    // <N*M x t1> -> <M x t2>. Convert each group of <N x t1> into a t2.
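    // E.g. (a sketch) <4 x i32> -> <2 x i64>: FanIn is 2, so each pair of
    // i32 source elements is gathered into a <2 x i32> value that is then
    // bitcast to a single i64 result element.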
    unsigned FanIn = SrcNumElems / DstNumElems;
    auto *MidTy = FixedVectorType::get(SrcVT->getElementType(), FanIn);
    unsigned Op0I = 0;
    for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
      Value *V = PoisonValue::get(MidTy);
      for (unsigned MidI = 0; MidI < FanIn; ++MidI)
        V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),
                                        BCI.getName() + ".i" + Twine(ResI)
                                        + ".upto" + Twine(MidI));
      Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(),
                                        BCI.getName() + ".i" + Twine(ResI));
    }
  }
  gather(&BCI, Res);
  return true;
}

bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  auto *VT = dyn_cast<FixedVectorType>(IEI.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&IEI);
  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0));
  Value *NewElt = IEI.getOperand(1);
  Value *InsIdx = IEI.getOperand(2);

  ValueVector Res;
  Res.resize(NumElems);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    for (unsigned I = 0; I < NumElems; ++I)
      Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I];
  } else {
    if (!ScalarizeVariableInsertExtract)
      return false;

    for (unsigned I = 0; I < NumElems; ++I) {
      Value *ShouldReplace =
          Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
                               InsIdx->getName() + ".is." + Twine(I));
      Value *OldElt = Op0[I];
      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
                                    IEI.getName() + ".i" + Twine(I));
    }
  }

  gather(&IEI, Res);
  return true;
}

bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  auto *VT = dyn_cast<FixedVectorType>(EEI.getOperand(0)->getType());
  if (!VT)
    return false;

  unsigned NumSrcElems = VT->getNumElements();
  IRBuilder<> Builder(&EEI);
  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0));
  Value *ExtIdx = EEI.getOperand(1);

  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
    Value *Res = Op0[CI->getValue().getZExtValue()];
    replaceUses(&EEI, Res);
    return true;
  }

  if (!ScalarizeVariableInsertExtract)
    return false;

  Value *Res = PoisonValue::get(VT->getElementType());
  for (unsigned I = 0; I < NumSrcElems; ++I) {
    Value *ShouldExtract =
        Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
                             ExtIdx->getName() + ".is." + Twine(I));
    Value *Elt = Op0[I];
    Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
                               EEI.getName() + ".upto" + Twine(I));
  }
  replaceUses(&EEI, Res);
  return true;
}

bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
  auto *VT = dyn_cast<FixedVectorType>(SVI.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
  Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
  ValueVector Res;
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I) {
    int Selector = SVI.getMaskValue(I);
    if (Selector < 0)
      Res[I] = UndefValue::get(VT->getElementType());
    else if (unsigned(Selector) < Op0.size())
      Res[I] = Op0[Selector];
    else
      Res[I] = Op1[Selector - Op0.size()];
  }
  gather(&SVI, Res);
  return true;
}

bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
  auto *VT = dyn_cast<FixedVectorType>(PHI.getType());
  if (!VT)
    return false;

  unsigned NumElems = VT->getNumElements();
  IRBuilder<> Builder(&PHI);
  ValueVector Res;
  Res.resize(NumElems);

  unsigned NumOps = PHI.getNumOperands();
  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps,
                               PHI.getName() + ".i" + Twine(I));

  for (unsigned I = 0; I < NumOps; ++I) {
    Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I));
    BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
    for (unsigned J = 0; J < NumElems; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
  }
  gather(&PHI, Res);
  return true;
}

bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!LI.isSimple())
    return false;

  std::optional<VectorLayout> Layout = getVectorLayout(
      LI.getType(), LI.getAlign(), LI.getModule()->getDataLayout());
  if (!Layout)
    return false;

  unsigned NumElems = Layout->VecTy->getNumElements();
  IRBuilder<> Builder(&LI);
  Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), LI.getType());
  ValueVector Res;
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I],
                                       Layout->getElemAlign(I),
                                       LI.getName() + ".i" + Twine(I));
  gather(&LI, Res);
  return true;
}

bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!SI.isSimple())
    return false;

  Value *FullValue = SI.getValueOperand();
  std::optional<VectorLayout> Layout = getVectorLayout(
      FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout());
  if (!Layout)
    return false;

  unsigned NumElems = Layout->VecTy->getNumElements();
  IRBuilder<> Builder(&SI);
  Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), FullValue->getType());
  Scatterer VVal = scatter(&SI, FullValue);

  ValueVector Stores;
  Stores.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    Value *Val = VVal[I];
    Value *Ptr = VPtr[I];
    Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I));
  }
  transferMetadataAndIRFlags(&SI, Stores);
  return true;
}

bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
}

// Delete the instructions that we scalarized. If a full vector result
// is still needed, recreate it using InsertElements.
bool ScalarizerVisitor::finish() {
  // The presence of data in Gathered or Scattered indicates changes
  // made to the Function.
  if (Gathered.empty() && Scattered.empty() && !Scalarized)
    return false;
  for (const auto &GMI : Gathered) {
    Instruction *Op = GMI.first;
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      // The value is still needed, so recreate it using a series of
      // InsertElements.
      Value *Res = PoisonValue::get(Op->getType());
      if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        unsigned Count = Ty->getNumElements();
        IRBuilder<> Builder(Op);
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
        for (unsigned I = 0; I < Count; ++I)
          Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
                                            Op->getName() + ".upto" + Twine(I));
        Res->takeName(Op);
      } else {
        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
        Res = CV[0];
        if (Op == Res)
          continue;
      }
      Op->replaceAllUsesWith(Res);
    }
    PotentiallyDeadInstrs.emplace_back(Op);
  }
  Gathered.clear();
  Scattered.clear();
  Scalarized = false;

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);

  return true;
}

PreservedAnalyses ScalarizerPass::run(Function &F,
                                      FunctionAnalysisManager &AM) {
  Module &M = *F.getParent();
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, Options);
  bool Changed = Impl.visit(F);
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return Changed ? PA : PreservedAnalyses::all();
}