1 //===- Scalarizer.cpp - Scalarize vector operations -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass converts vector operations into scalar operations (or, optionally, 10 // operations on smaller vector widths), in order to expose optimization 11 // opportunities on the individual scalar operations. 12 // It is mainly intended for targets that do not have vector units, but it 13 // may also be useful for revectorizing code to different vector widths. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/Transforms/Scalar/Scalarizer.h" 18 #include "llvm/ADT/PostOrderIterator.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/ADT/Twine.h" 21 #include "llvm/Analysis/TargetTransformInfo.h" 22 #include "llvm/Analysis/VectorUtils.h" 23 #include "llvm/IR/Argument.h" 24 #include "llvm/IR/BasicBlock.h" 25 #include "llvm/IR/Constants.h" 26 #include "llvm/IR/DataLayout.h" 27 #include "llvm/IR/DerivedTypes.h" 28 #include "llvm/IR/Dominators.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/InstVisitor.h" 32 #include "llvm/IR/InstrTypes.h" 33 #include "llvm/IR/Instruction.h" 34 #include "llvm/IR/Instructions.h" 35 #include "llvm/IR/Intrinsics.h" 36 #include "llvm/IR/LLVMContext.h" 37 #include "llvm/IR/Module.h" 38 #include "llvm/IR/Type.h" 39 #include "llvm/IR/Value.h" 40 #include "llvm/InitializePasses.h" 41 #include "llvm/Support/Casting.h" 42 #include "llvm/Transforms/Utils/Local.h" 43 #include <cassert> 44 #include <cstdint> 45 #include <iterator> 46 #include <map> 47 #include <utility> 48 49 using namespace llvm; 50 51 #define DEBUG_TYPE "scalarizer" 52 53 namespace { 54 55 
// Advance Itr past any PHI nodes (restarting at the block's first insertion
// point) and past debug intrinsics, so that it denotes a position where
// ordinary instructions may legally be inserted.
BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
  BasicBlock *BB = Itr->getParent();
  if (isa<PHINode>(Itr))
    Itr = BB->getFirstInsertionPt();
  if (Itr != BB->end())
    Itr = skipDebugIntrinsics(Itr);
  return Itr;
}

// Used to store the scattered form of a vector.
using ValueVector = SmallVector<Value *, 8>;

// Used to map a vector Value and associated type to its scattered form.
// The associated type is only non-null for pointer values that are "scattered"
// when used as pointer operands to load or store.
//
// We use std::map because we want iterators to persist across insertion and
// because the values are relatively large.
using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>;

// Lists Instructions that have been replaced with scalar implementations,
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;

// Describes how a fixed vector type is (or would be) broken into fragments.
struct VectorSplit {
  // The type of the vector.
  FixedVectorType *VecTy = nullptr;

  // The number of elements packed in a fragment (other than the remainder).
  unsigned NumPacked = 0;

  // The number of fragments (scalars or smaller vectors) into which the vector
  // shall be split.
  unsigned NumFragments = 0;

  // The type of each complete fragment.
  Type *SplitTy = nullptr;

  // The type of the remainder (last) fragment; null if all fragments are
  // complete.
  Type *RemainderTy = nullptr;

  // Return the type of fragment I, accounting for a possibly smaller final
  // (remainder) fragment.
  Type *getFragmentType(unsigned I) const {
    return RemainderTy && I == NumFragments - 1 ? RemainderTy : SplitTy;
  }
};

// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
public:
  Scatterer() = default;

  // Scatter V into Size components. If new instructions are needed,
  // insert them before BBI in BB. If Cache is nonnull, use it to cache
  // the results.
  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
            const VectorSplit &VS, ValueVector *cachePtr = nullptr);

  // Return component I, creating a new Value for it if necessary.
  Value *operator[](unsigned I);

  // Return the number of components.
  unsigned size() const { return VS.NumFragments; }

private:
  BasicBlock *BB;           // Block in which new instructions are inserted.
  BasicBlock::iterator BBI; // Insertion point for new instructions.
  Value *V;                 // The vector (or vector pointer) being scattered.
  VectorSplit VS;           // How V is divided into fragments.
  bool IsPointer;           // True if V's type is a pointer type.
  ValueVector *CachePtr;    // External fragment cache; may be null.
  ValueVector Tmp;          // Local fragment storage when CachePtr is null.
};

// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
// called Name that compares X and Y in the same way as FCI.
struct FCmpSplitter {
  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
  }

  FCmpInst &FCI;
};

// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
// called Name that compares X and Y in the same way as ICI.
struct ICmpSplitter {
  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
  }

  ICmpInst &ICI;
};

// UnarySplitter(UO)(Builder, X, Name) uses Builder to create
// a unary operator like UO called Name with operand X.
struct UnarySplitter {
  UnarySplitter(UnaryOperator &uo) : UO(uo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
    return Builder.CreateUnOp(UO.getOpcode(), Op, Name);
  }

  UnaryOperator &UO;
};

// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
// a binary operator like BO called Name with operands X and Y.
struct BinarySplitter {
  BinarySplitter(BinaryOperator &bo) : BO(bo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
  }

  BinaryOperator &BO;
};

// Information about a load or store that we're scalarizing.
struct VectorLayout {
  VectorLayout() = default;

  // Return the alignment of fragment Frag.
  Align getFragmentAlign(unsigned Frag) {
    return commonAlignment(VecAlign, Frag * SplitSize);
  }

  // The split of the underlying vector type.
  VectorSplit VS;

  // The alignment of the vector.
  Align VecAlign;

  // The size of each (non-remainder) fragment in bytes.
  uint64_t SplitSize = 0;
};

// Return true if Ty is a struct whose members are all fixed vectors with the
// same number of elements (the element types themselves may differ).
static bool isStructOfMatchingFixedVectors(Type *Ty) {
  if (!isa<StructType>(Ty))
    return false;
  unsigned StructSize = Ty->getNumContainedTypes();
  if (StructSize < 1)
    return false;
  FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(0));
  if (!VecTy)
    return false;
  unsigned VecSize = VecTy->getNumElements();
  for (unsigned I = 1; I < StructSize; I++) {
    VecTy = dyn_cast<FixedVectorType>(Ty->getContainedType(I));
    if (!VecTy || VecSize != VecTy->getNumElements())
      return false;
  }
  return true;
}

/// Concatenate the given fragments to a single vector value of the type
/// described in @p VS.
static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
                          const VectorSplit &VS, Twine Name) {
  unsigned NumElements = VS.VecTy->getNumElements();
  SmallVector<int> ExtendMask;
  SmallVector<int> InsertMask;

  if (VS.NumPacked > 1) {
    // Prepare the shufflevector masks once and re-use them for all
    // fragments.
    ExtendMask.resize(NumElements, -1);
    for (unsigned I = 0; I < VS.NumPacked; ++I)
      ExtendMask[I] = I;

    InsertMask.resize(NumElements);
    for (unsigned I = 0; I < NumElements; ++I)
      InsertMask[I] = I;
  }

  Value *Res = PoisonValue::get(VS.VecTy);
  for (unsigned I = 0; I < VS.NumFragments; ++I) {
    Value *Fragment = Fragments[I];

    unsigned NumPacked = VS.NumPacked;
    if (I == VS.NumFragments - 1 && VS.RemainderTy) {
      if (auto *RemVecTy = dyn_cast<FixedVectorType>(VS.RemainderTy))
        NumPacked = RemVecTy->getNumElements();
      else
        NumPacked = 1;
    }

    if (NumPacked == 1) {
      Res = Builder.CreateInsertElement(Res, Fragment, I * VS.NumPacked,
                                        Name + ".upto" + Twine(I));
    } else {
      // Widen the fragment to the full vector width, then blend it into the
      // result accumulated so far.
      Fragment = Builder.CreateShuffleVector(Fragment, Fragment, ExtendMask);
      if (I == 0) {
        Res = Fragment;
      } else {
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * VS.NumPacked + J] = NumElements + J;
        Res = Builder.CreateShuffleVector(Res, Fragment, InsertMask,
                                          Name + ".upto" + Twine(I));
        // Restore the identity entries so the mask can be reused for the
        // next fragment.
        for (unsigned J = 0; J < NumPacked; ++J)
          InsertMask[I * VS.NumPacked + J] = I * VS.NumPacked + J;
      }
    }
  }

  return Res;
}

// Visitor that performs the scalarization of a function's vector operations.
// Visit methods return true if the instruction was scalarized, false if
// nothing changed.
class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
public:
  ScalarizerVisitor(DominatorTree *DT, const TargetTransformInfo *TTI,
                    ScalarizerPassOptions Options)
      : DT(DT), TTI(TTI),
        ScalarizeVariableInsertExtract(Options.ScalarizeVariableInsertExtract),
        ScalarizeLoadStore(Options.ScalarizeLoadStore),
        ScalarizeMinBits(Options.ScalarizeMinBits) {}

  bool visit(Function &F);

  // InstVisitor methods. They return true if the instruction was scalarized,
  // false if nothing changed.
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI);
  bool visitICmpInst(ICmpInst &ICI);
  bool visitFCmpInst(FCmpInst &FCI);
  bool visitUnaryOperator(UnaryOperator &UO);
  bool visitBinaryOperator(BinaryOperator &BO);
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitCastInst(CastInst &CI);
  bool visitBitCastInst(BitCastInst &BCI);
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitExtractValueInst(ExtractValueInst &EVI);
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
  bool visitPHINode(PHINode &PHI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &ICI);
  bool visitFreezeInst(FreezeInst &FI);

private:
  Scatterer scatter(Instruction *Point, Value *V, const VectorSplit &VS);
  void gather(Instruction *Op, const ValueVector &CV, const VectorSplit &VS);
  void replaceUses(Instruction *Op, Value *CV);
  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
  std::optional<VectorSplit> getVectorSplit(Type *Ty);
  std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
                                              const DataLayout &DL);
  bool finish();

  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);

  bool splitCall(CallInst &CI);

  ScatterMap Scattered;
  GatherList Gathered;
  bool Scalarized;

  // Instructions made redundant by scalarization; cleaned up in finish().
  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;

  DominatorTree *DT;
  const TargetTransformInfo *TTI;

  const bool ScalarizeVariableInsertExtract;
  const bool ScalarizeLoadStore;
  const unsigned ScalarizeMinBits;
};

// Legacy pass-manager wrapper that runs ScalarizerVisitor over a function.
class ScalarizerLegacyPass : public FunctionPass {
public:
  static char ID;
  ScalarizerPassOptions Options;
  ScalarizerLegacyPass() : FunctionPass(ID), Options() {}
  ScalarizerLegacyPass(const ScalarizerPassOptions &Options);
  bool runOnFunction(Function &F) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;
};

} // end anonymous namespace

ScalarizerLegacyPass::ScalarizerLegacyPass(const ScalarizerPassOptions &Options)
    : FunctionPass(ID), Options(Options) {}

void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
}

char ScalarizerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
                      "Scalarize vector operations", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
                    "Scalarize vector operations", false, false)

Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                     const VectorSplit &VS, ValueVector *cachePtr)
    : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
  IsPointer = V->getType()->isPointerTy();
  if (!CachePtr) {
    Tmp.resize(VS.NumFragments, nullptr);
  } else {
    assert((CachePtr->empty() || VS.NumFragments == CachePtr->size() ||
            IsPointer) &&
           "Inconsistent vector sizes");
    if (VS.NumFragments > CachePtr->size())
      CachePtr->resize(VS.NumFragments, nullptr);
  }
}

// Return fragment Frag, creating a new Value for it if necessary.
Value *Scatterer::operator[](unsigned Frag) {
  ValueVector &CV = CachePtr ? *CachePtr : Tmp;
  // Try to reuse a previous value.
  if (CV[Frag])
    return CV[Frag];
  IRBuilder<> Builder(BB, BBI);
  if (IsPointer) {
    // For a pointer, fragment 0 is the pointer itself and later fragments
    // are consecutive SplitTy-sized offsets from it.
    if (Frag == 0)
      CV[Frag] = V;
    else
      CV[Frag] = Builder.CreateConstGEP1_32(VS.SplitTy, V, Frag,
                                            V->getName() + ".i" + Twine(Frag));
    return CV[Frag];
  }

  Type *FragmentTy = VS.getFragmentType(Frag);

  if (auto *VecTy = dyn_cast<FixedVectorType>(FragmentTy)) {
    // A sub-vector fragment is extracted with a shuffle.
    SmallVector<int> Mask;
    for (unsigned J = 0; J < VecTy->getNumElements(); ++J)
      Mask.push_back(Frag * VS.NumPacked + J);
    CV[Frag] =
        Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()), Mask,
                                    V->getName() + ".i" + Twine(Frag));
  } else {
    // Search through a chain of InsertElementInsts looking for element Frag.
    // Record other elements in the cache. The new V is still suitable
    // for all uncached indices.
    while (true) {
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
      if (!Insert)
        break;
      ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
      if (!Idx)
        break;
      unsigned J = Idx->getZExtValue();
      V = Insert->getOperand(0);
      if (Frag * VS.NumPacked == J) {
        CV[Frag] = Insert->getOperand(1);
        return CV[Frag];
      }

      if (VS.NumPacked == 1 && !CV[J]) {
        // Only cache the first entry we find for each index we're not actively
        // searching for. This prevents us from going too far up the chain and
        // caching incorrect entries.
        CV[J] = Insert->getOperand(1);
      }
    }
    CV[Frag] = Builder.CreateExtractElement(V, Frag * VS.NumPacked,
                                            V->getName() + ".i" + Twine(Frag));
  }

  return CV[Frag];
}

bool ScalarizerLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  const TargetTransformInfo *TTI =
      &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  ScalarizerVisitor Impl(DT, TTI, Options);
  return Impl.visit(F);
}

FunctionPass *llvm::createScalarizerPass(const ScalarizerPassOptions &Options) {
  return new ScalarizerLegacyPass(Options);
}

bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());

  Scalarized = false;

  // To ensure we replace gathered components correctly we need to do an ordered
  // traversal of the basic blocks in the function.
  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
  for (BasicBlock *BB : RPOT) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      Instruction *I = &*II;
      bool Done = InstVisitor::visit(I);
      ++II;
      // Void instructions (e.g. stores) produce no replacement value and can
      // be removed immediately once scalarized.
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
    }
  }
  return finish();
}

// Return a scattered form of V that can be accessed by Point. V must be a
// vector or a pointer to a vector.
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
                                     const VectorSplit &VS) {
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    // Put the scattered form of arguments in the entry block,
    // so that it can be used everywhere.
    Function *F = VArg->getParent();
    BasicBlock *BB = &F->getEntryBlock();
    return Scatterer(BB, BB->begin(), V, VS, &Scattered[{V, VS.SplitTy}]);
  }
  if (Instruction *VOp = dyn_cast<Instruction>(V)) {
    // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
    // nodes in predecessors. If those predecessors are unreachable from entry,
    // then the IR in those blocks could have unexpected properties resulting in
    // infinite loops in Scatterer::operator[]. By simply treating values
    // originating from instructions in unreachable blocks as undef we do not
    // need to analyse them further.
    if (!DT->isReachableFromEntry(VOp->getParent()))
      return Scatterer(Point->getParent(), Point->getIterator(),
                       PoisonValue::get(V->getType()), VS);
    // Put the scattered form of an instruction directly after the
    // instruction, skipping over PHI nodes and debug intrinsics.
    BasicBlock *BB = VOp->getParent();
    return Scatterer(
        BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, VS,
        &Scattered[{V, VS.SplitTy}]);
  }
  // In the fallback case, just put the scattered before Point and
  // keep the result local to Point.
  return Scatterer(Point->getParent(), Point->getIterator(), V, VS);
}

// Replace Op with the gathered form of the components in CV. Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV,
                               const VectorSplit &VS) {
  transferMetadataAndIRFlags(Op, CV);

  // If we already have a scattered form of Op (created from ExtractElements
  // of Op itself), replace them with the new form.
  ValueVector &SV = Scattered[{Op, VS.SplitTy}];
  if (!SV.empty()) {
    for (unsigned I = 0, E = SV.size(); I != E; ++I) {
      Value *V = SV[I];
      if (V == nullptr || SV[I] == CV[I])
        continue;

      Instruction *Old = cast<Instruction>(V);
      if (isa<Instruction>(CV[I]))
        CV[I]->takeName(Old);
      Old->replaceAllUsesWith(CV[I]);
      PotentiallyDeadInstrs.emplace_back(Old);
    }
  }
  SV = CV;
  Gathered.push_back(GatherList::value_type(Op, &SV));
}

// Replace Op with CV and collect Op as a potentially dead instruction.
void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) {
  if (CV != Op) {
    Op->replaceAllUsesWith(CV);
    PotentiallyDeadInstrs.emplace_back(Op);
    Scalarized = true;
  }
}

// Return true if it is safe to transfer the given metadata tag from
// vector to scalar instructions.
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
  return (Tag == LLVMContext::MD_tbaa
          || Tag == LLVMContext::MD_fpmath
          || Tag == LLVMContext::MD_tbaa_struct
          || Tag == LLVMContext::MD_invariant_load
          || Tag == LLVMContext::MD_alias_scope
          || Tag == LLVMContext::MD_noalias
          || Tag == LLVMContext::MD_mem_parallel_loop_access
          || Tag == LLVMContext::MD_access_group);
}

// Transfer metadata from Op to the instructions in CV if it is known
// to be safe to do so.
void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
                                                   const ValueVector &CV) {
  SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
  Op->getAllMetadataOtherThanDebugLoc(MDs);
  for (Value *V : CV) {
    if (Instruction *New = dyn_cast<Instruction>(V)) {
      for (const auto &MD : MDs)
        if (canTransferMetadata(MD.first))
          New->setMetadata(MD.first, MD.second);
      New->copyIRFlags(Op);
      if (Op->getDebugLoc() && !New->getDebugLoc())
        New->setDebugLoc(Op->getDebugLoc());
    }
  }
}

// Determine how Ty is split, if at all.
570 std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) { 571 VectorSplit Split; 572 Split.VecTy = dyn_cast<FixedVectorType>(Ty); 573 if (!Split.VecTy) 574 return {}; 575 576 unsigned NumElems = Split.VecTy->getNumElements(); 577 Type *ElemTy = Split.VecTy->getElementType(); 578 579 if (NumElems == 1 || ElemTy->isPointerTy() || 580 2 * ElemTy->getScalarSizeInBits() > ScalarizeMinBits) { 581 Split.NumPacked = 1; 582 Split.NumFragments = NumElems; 583 Split.SplitTy = ElemTy; 584 } else { 585 Split.NumPacked = ScalarizeMinBits / ElemTy->getScalarSizeInBits(); 586 if (Split.NumPacked >= NumElems) 587 return {}; 588 589 Split.NumFragments = divideCeil(NumElems, Split.NumPacked); 590 Split.SplitTy = FixedVectorType::get(ElemTy, Split.NumPacked); 591 592 unsigned RemainderElems = NumElems % Split.NumPacked; 593 if (RemainderElems > 1) 594 Split.RemainderTy = FixedVectorType::get(ElemTy, RemainderElems); 595 else if (RemainderElems == 1) 596 Split.RemainderTy = ElemTy; 597 } 598 599 return Split; 600 } 601 602 // Try to fill in Layout from Ty, returning true on success. Alignment is 603 // the alignment of the vector, or std::nullopt if the ABI default should be 604 // used. 605 std::optional<VectorLayout> 606 ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment, 607 const DataLayout &DL) { 608 std::optional<VectorSplit> VS = getVectorSplit(Ty); 609 if (!VS) 610 return {}; 611 612 VectorLayout Layout; 613 Layout.VS = *VS; 614 // Check that we're dealing with full-byte fragments. 615 if (!DL.typeSizeEqualsStoreSize(VS->SplitTy) || 616 (VS->RemainderTy && !DL.typeSizeEqualsStoreSize(VS->RemainderTy))) 617 return {}; 618 Layout.VecAlign = Alignment; 619 Layout.SplitSize = DL.getTypeStoreSize(VS->SplitTy); 620 return Layout; 621 } 622 623 // Scalarize one-operand instruction I, using Split(Builder, X, Name) 624 // to create an instruction like I with operand X and name Name. 
625 template<typename Splitter> 626 bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) { 627 std::optional<VectorSplit> VS = getVectorSplit(I.getType()); 628 if (!VS) 629 return false; 630 631 std::optional<VectorSplit> OpVS; 632 if (I.getOperand(0)->getType() == I.getType()) { 633 OpVS = VS; 634 } else { 635 OpVS = getVectorSplit(I.getOperand(0)->getType()); 636 if (!OpVS || VS->NumPacked != OpVS->NumPacked) 637 return false; 638 } 639 640 IRBuilder<> Builder(&I); 641 Scatterer Op = scatter(&I, I.getOperand(0), *OpVS); 642 assert(Op.size() == VS->NumFragments && "Mismatched unary operation"); 643 ValueVector Res; 644 Res.resize(VS->NumFragments); 645 for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) 646 Res[Frag] = Split(Builder, Op[Frag], I.getName() + ".i" + Twine(Frag)); 647 gather(&I, Res, *VS); 648 return true; 649 } 650 651 // Scalarize two-operand instruction I, using Split(Builder, X, Y, Name) 652 // to create an instruction like I with operands X and Y and name Name. 
653 template<typename Splitter> 654 bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) { 655 std::optional<VectorSplit> VS = getVectorSplit(I.getType()); 656 if (!VS) 657 return false; 658 659 std::optional<VectorSplit> OpVS; 660 if (I.getOperand(0)->getType() == I.getType()) { 661 OpVS = VS; 662 } else { 663 OpVS = getVectorSplit(I.getOperand(0)->getType()); 664 if (!OpVS || VS->NumPacked != OpVS->NumPacked) 665 return false; 666 } 667 668 IRBuilder<> Builder(&I); 669 Scatterer VOp0 = scatter(&I, I.getOperand(0), *OpVS); 670 Scatterer VOp1 = scatter(&I, I.getOperand(1), *OpVS); 671 assert(VOp0.size() == VS->NumFragments && "Mismatched binary operation"); 672 assert(VOp1.size() == VS->NumFragments && "Mismatched binary operation"); 673 ValueVector Res; 674 Res.resize(VS->NumFragments); 675 for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) { 676 Value *Op0 = VOp0[Frag]; 677 Value *Op1 = VOp1[Frag]; 678 Res[Frag] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Frag)); 679 } 680 gather(&I, Res, *VS); 681 return true; 682 } 683 684 /// If a call to a vector typed intrinsic function, split into a scalar call per 685 /// element if possible for the intrinsic. 
686 bool ScalarizerVisitor::splitCall(CallInst &CI) { 687 Type *CallType = CI.getType(); 688 bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors(CallType); 689 std::optional<VectorSplit> VS; 690 if (AreAllVectorsOfMatchingSize) 691 VS = getVectorSplit(CallType->getContainedType(0)); 692 else 693 VS = getVectorSplit(CallType); 694 if (!VS) 695 return false; 696 697 Function *F = CI.getCalledFunction(); 698 if (!F) 699 return false; 700 701 Intrinsic::ID ID = F->getIntrinsicID(); 702 703 if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI)) 704 return false; 705 706 // unsigned NumElems = VT->getNumElements(); 707 unsigned NumArgs = CI.arg_size(); 708 709 ValueVector ScalarOperands(NumArgs); 710 SmallVector<Scatterer, 8> Scattered(NumArgs); 711 SmallVector<int> OverloadIdx(NumArgs, -1); 712 713 SmallVector<llvm::Type *, 3> Tys; 714 // Add return type if intrinsic is overloaded on it. 715 if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI)) 716 Tys.push_back(VS->SplitTy); 717 718 if (AreAllVectorsOfMatchingSize) { 719 for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) { 720 std::optional<VectorSplit> CurrVS = 721 getVectorSplit(cast<FixedVectorType>(CallType->getContainedType(I))); 722 // This case does not seem to happen, but it is possible for 723 // VectorSplit.NumPacked >= NumElems. If that happens a VectorSplit 724 // is not returned and we will bailout of handling this call. 725 // The secondary bailout case is if NumPacked does not match. 726 // This can happen if ScalarizeMinBits is not set to the default. 727 // This means with certain ScalarizeMinBits intrinsics like frexp 728 // will only scalarize when the struct elements have the same bitness. 
729 if (!CurrVS || CurrVS->NumPacked != VS->NumPacked) 730 return false; 731 if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI)) 732 Tys.push_back(CurrVS->SplitTy); 733 } 734 } 735 // Assumes that any vector type has the same number of elements as the return 736 // vector type, which is true for all current intrinsics. 737 for (unsigned I = 0; I != NumArgs; ++I) { 738 Value *OpI = CI.getOperand(I); 739 if ([[maybe_unused]] auto *OpVecTy = 740 dyn_cast<FixedVectorType>(OpI->getType())) { 741 assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements()); 742 std::optional<VectorSplit> OpVS = getVectorSplit(OpI->getType()); 743 if (!OpVS || OpVS->NumPacked != VS->NumPacked) { 744 // The natural split of the operand doesn't match the result. This could 745 // happen if the vector elements are different and the ScalarizeMinBits 746 // option is used. 747 // 748 // We could in principle handle this case as well, at the cost of 749 // complicating the scattering machinery to support multiple scattering 750 // granularities for a single value. 751 return false; 752 } 753 754 Scattered[I] = scatter(&CI, OpI, *OpVS); 755 if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) { 756 OverloadIdx[I] = Tys.size(); 757 Tys.push_back(OpVS->SplitTy); 758 } 759 } else { 760 ScalarOperands[I] = OpI; 761 if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) 762 Tys.push_back(OpI->getType()); 763 } 764 } 765 766 ValueVector Res(VS->NumFragments); 767 ValueVector ScalarCallOps(NumArgs); 768 769 Function *NewIntrin = 770 Intrinsic::getOrInsertDeclaration(F->getParent(), ID, Tys); 771 IRBuilder<> Builder(&CI); 772 773 // Perform actual scalarization, taking care to preserve any scalar operands. 
774 for (unsigned I = 0; I < VS->NumFragments; ++I) { 775 bool IsRemainder = I == VS->NumFragments - 1 && VS->RemainderTy; 776 ScalarCallOps.clear(); 777 778 if (IsRemainder) 779 Tys[0] = VS->RemainderTy; 780 781 for (unsigned J = 0; J != NumArgs; ++J) { 782 if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) { 783 ScalarCallOps.push_back(ScalarOperands[J]); 784 } else { 785 ScalarCallOps.push_back(Scattered[J][I]); 786 if (IsRemainder && OverloadIdx[J] >= 0) 787 Tys[OverloadIdx[J]] = Scattered[J][I]->getType(); 788 } 789 } 790 791 if (IsRemainder) 792 NewIntrin = Intrinsic::getOrInsertDeclaration(F->getParent(), ID, Tys); 793 794 Res[I] = Builder.CreateCall(NewIntrin, ScalarCallOps, 795 CI.getName() + ".i" + Twine(I)); 796 } 797 798 gather(&CI, Res, *VS); 799 return true; 800 } 801 802 bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) { 803 std::optional<VectorSplit> VS = getVectorSplit(SI.getType()); 804 if (!VS) 805 return false; 806 807 std::optional<VectorSplit> CondVS; 808 if (isa<FixedVectorType>(SI.getCondition()->getType())) { 809 CondVS = getVectorSplit(SI.getCondition()->getType()); 810 if (!CondVS || CondVS->NumPacked != VS->NumPacked) { 811 // This happens when ScalarizeMinBits is used. 
812 return false; 813 } 814 } 815 816 IRBuilder<> Builder(&SI); 817 Scatterer VOp1 = scatter(&SI, SI.getOperand(1), *VS); 818 Scatterer VOp2 = scatter(&SI, SI.getOperand(2), *VS); 819 assert(VOp1.size() == VS->NumFragments && "Mismatched select"); 820 assert(VOp2.size() == VS->NumFragments && "Mismatched select"); 821 ValueVector Res; 822 Res.resize(VS->NumFragments); 823 824 if (CondVS) { 825 Scatterer VOp0 = scatter(&SI, SI.getOperand(0), *CondVS); 826 assert(VOp0.size() == CondVS->NumFragments && "Mismatched select"); 827 for (unsigned I = 0; I < VS->NumFragments; ++I) { 828 Value *Op0 = VOp0[I]; 829 Value *Op1 = VOp1[I]; 830 Value *Op2 = VOp2[I]; 831 Res[I] = Builder.CreateSelect(Op0, Op1, Op2, 832 SI.getName() + ".i" + Twine(I)); 833 } 834 } else { 835 Value *Op0 = SI.getOperand(0); 836 for (unsigned I = 0; I < VS->NumFragments; ++I) { 837 Value *Op1 = VOp1[I]; 838 Value *Op2 = VOp2[I]; 839 Res[I] = Builder.CreateSelect(Op0, Op1, Op2, 840 SI.getName() + ".i" + Twine(I)); 841 } 842 } 843 gather(&SI, Res, *VS); 844 return true; 845 } 846 847 bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) { 848 return splitBinary(ICI, ICmpSplitter(ICI)); 849 } 850 851 bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) { 852 return splitBinary(FCI, FCmpSplitter(FCI)); 853 } 854 855 bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) { 856 return splitUnary(UO, UnarySplitter(UO)); 857 } 858 859 bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) { 860 return splitBinary(BO, BinarySplitter(BO)); 861 } 862 863 bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { 864 std::optional<VectorSplit> VS = getVectorSplit(GEPI.getType()); 865 if (!VS) 866 return false; 867 868 IRBuilder<> Builder(&GEPI); 869 unsigned NumIndices = GEPI.getNumIndices(); 870 871 // The base pointer and indices might be scalar even if it's a vector GEP. 
  // Slot I holds either the original scalar operand (ScalarOps) or the
  // scattered fragments of a vector operand (ScatterOps) — never both.
  SmallVector<Value *, 8> ScalarOps{1 + NumIndices};
  SmallVector<Scatterer, 8> ScatterOps{1 + NumIndices};

  for (unsigned I = 0; I < 1 + NumIndices; ++I) {
    if (auto *VecTy =
            dyn_cast<FixedVectorType>(GEPI.getOperand(I)->getType())) {
      std::optional<VectorSplit> OpVS = getVectorSplit(VecTy);
      if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
        // This can happen when ScalarizeMinBits is used.
        return false;
      }
      ScatterOps[I] = scatter(&GEPI, GEPI.getOperand(I), *OpVS);
    } else {
      ScalarOps[I] = GEPI.getOperand(I);
    }
  }

  ValueVector Res;
  Res.resize(VS->NumFragments);
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    // Assemble the operand list for this fragment's GEP: entry J is either
    // the shared scalar operand or fragment I of the scattered operand.
    SmallVector<Value *, 8> SplitOps;
    SplitOps.resize(1 + NumIndices);
    for (unsigned J = 0; J < 1 + NumIndices; ++J) {
      if (ScalarOps[J])
        SplitOps[J] = ScalarOps[J];
      else
        SplitOps[J] = ScatterOps[J][I];
    }
    Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), SplitOps[0],
                               ArrayRef(SplitOps).drop_front(),
                               GEPI.getName() + ".i" + Twine(I));
    // Propagate inbounds to each per-fragment GEP; the dyn_cast guards
    // against the builder having folded the result to a non-GEP value.
    if (GEPI.isInBounds())
      if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
        NewGEPI->setIsInBounds();
  }
  gather(&GEPI, Res, *VS);
  return true;
}

// Scalarize a cast: source and destination splits must pack the same number
// of elements per fragment, so the cast can be replayed fragment-by-fragment.
bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
  std::optional<VectorSplit> DestVS = getVectorSplit(CI.getDestTy());
  if (!DestVS)
    return false;

  std::optional<VectorSplit> SrcVS = getVectorSplit(CI.getSrcTy());
  if (!SrcVS || SrcVS->NumPacked != DestVS->NumPacked)
    return false;

  IRBuilder<> Builder(&CI);
  Scatterer Op0 = scatter(&CI, CI.getOperand(0), *SrcVS);
  assert(Op0.size() == SrcVS->NumFragments && "Mismatched cast");
  ValueVector Res;
  Res.resize(DestVS->NumFragments);
  for (unsigned I = 0; I < DestVS->NumFragments; ++I)
    Res[I] =
        Builder.CreateCast(CI.getOpcode(), Op0[I], DestVS->getFragmentType(I),
                           CI.getName() + ".i" + Twine(I));
  gather(&CI, Res, *DestVS);
  return true;
}

// Scalarize a bitcast. Source and destination fragments may differ in size,
// in which case one side's fragments are regrouped through an intermediate
// vector type (MidVS below). Splits with a remainder are not handled.
bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
  std::optional<VectorSplit> DstVS = getVectorSplit(BCI.getDestTy());
  std::optional<VectorSplit> SrcVS = getVectorSplit(BCI.getSrcTy());
  if (!DstVS || !SrcVS || DstVS->RemainderTy || SrcVS->RemainderTy)
    return false;

  const bool isPointerTy = DstVS->VecTy->getElementType()->isPointerTy();

  // Vectors of pointers are always fully scalarized.
  assert(!isPointerTy || (DstVS->NumPacked == 1 && SrcVS->NumPacked == 1));

  IRBuilder<> Builder(&BCI);
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0), *SrcVS);
  ValueVector Res;
  Res.resize(DstVS->NumFragments);

  unsigned DstSplitBits = DstVS->SplitTy->getPrimitiveSizeInBits();
  unsigned SrcSplitBits = SrcVS->SplitTy->getPrimitiveSizeInBits();

  if (isPointerTy || DstSplitBits == SrcSplitBits) {
    // Same fragment size on both sides: fragments correspond 1:1, so each
    // one can be bitcast directly.
    assert(DstVS->NumFragments == SrcVS->NumFragments);
    for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
      Res[I] = Builder.CreateBitCast(Op0[I], DstVS->getFragmentType(I),
                                     BCI.getName() + ".i" + Twine(I));
    }
  } else if (SrcSplitBits % DstSplitBits == 0) {
    // Convert each source fragment to the same-sized destination vector and
    // then scatter the result to the destination.
    VectorSplit MidVS;
    MidVS.NumPacked = DstVS->NumPacked;
    MidVS.NumFragments = SrcSplitBits / DstSplitBits;
    MidVS.VecTy = FixedVectorType::get(DstVS->VecTy->getElementType(),
                                       MidVS.NumPacked * MidVS.NumFragments);
    MidVS.SplitTy = DstVS->SplitTy;

    unsigned ResI = 0;
    for (unsigned I = 0; I < SrcVS->NumFragments; ++I) {
      Value *V = Op0[I];

      // Look through any existing bitcasts before converting to <N x t2>.
      // In the best case, the resulting conversion might be a no-op.
      Instruction *VI;
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);

      V = Builder.CreateBitCast(V, MidVS.VecTy, V->getName() + ".cast");

      // Re-scatter the intermediate vector; its fragments are exactly the
      // destination fragments covered by this source fragment.
      Scatterer Mid = scatter(&BCI, V, MidVS);
      for (unsigned J = 0; J < MidVS.NumFragments; ++J)
        Res[ResI++] = Mid[J];
    }
  } else if (DstSplitBits % SrcSplitBits == 0) {
    // Gather enough source fragments to make up a destination fragment and
    // then convert to the destination type.
    VectorSplit MidVS;
    MidVS.NumFragments = DstSplitBits / SrcSplitBits;
    MidVS.NumPacked = SrcVS->NumPacked;
    MidVS.VecTy = FixedVectorType::get(SrcVS->VecTy->getElementType(),
                                       MidVS.NumPacked * MidVS.NumFragments);
    MidVS.SplitTy = SrcVS->SplitTy;

    unsigned SrcI = 0;
    SmallVector<Value *, 8> ConcatOps;
    ConcatOps.resize(MidVS.NumFragments);
    for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
      for (unsigned J = 0; J < MidVS.NumFragments; ++J)
        ConcatOps[J] = Op0[SrcI++];
      Value *V = concatenate(Builder, ConcatOps, MidVS,
                             BCI.getName() + ".i" + Twine(I));
      Res[I] = Builder.CreateBitCast(V, DstVS->getFragmentType(I),
                                     BCI.getName() + ".i" + Twine(I));
    }
  } else {
    // Neither fragment size divides the other; leave the bitcast alone.
    return false;
  }

  gather(&BCI, Res, *DstVS);
  return true;
}

// Scalarize an insertelement. With a constant index only one fragment is
// affected; with a variable index (and full scalarization enabled) every
// element becomes a select keyed on the index.
bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  std::optional<VectorSplit> VS = getVectorSplit(IEI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&IEI);
  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0), *VS);
  Value *NewElt = IEI.getOperand(1);
  Value *InsIdx = IEI.getOperand(2);

  ValueVector Res;
  Res.resize(VS->NumFragments);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    // Constant index: fragment Idx / NumPacked holds the target element;
    // all other fragments pass through unchanged.
    unsigned Idx = CI->getZExtValue();
    unsigned Fragment = Idx / VS->NumPacked;
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      if (I == Fragment) {
        // A scalar remainder fragment is replaced wholesale; a packed
        // (vector) fragment needs an insertelement at the sub-index.
        bool IsPacked = VS->NumPacked > 1;
        if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
            !VS->RemainderTy->isVectorTy())
          IsPacked = false;
        if (IsPacked) {
          Res[I] =
              Builder.CreateInsertElement(Op0[I], NewElt, Idx % VS->NumPacked);
        } else {
          Res[I] = NewElt;
        }
      } else {
        Res[I] = Op0[I];
      }
    }
  } else {
    // Never split a variable insertelement that isn't fully scalarized.
    if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
      return false;

    // Variable index: element I becomes select(InsIdx == I, NewElt, OldElt).
    for (unsigned I = 0; I < VS->NumFragments; ++I) {
      Value *ShouldReplace =
          Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
                               InsIdx->getName() + ".is." + Twine(I));
      Value *OldElt = Op0[I];
      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
                                    IEI.getName() + ".i" + Twine(I));
    }
  }

  gather(&IEI, Res, *VS);
  return true;
}

// Scalarize an extractvalue whose aggregate operand is a call to a trivially
// scalarizable intrinsic returning a struct of matching fixed vectors. Any
// other aggregate source is left alone.
bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
  Value *Op = EVI.getOperand(0);
  Type *OpTy = Op->getType();
  ValueVector Res;
  if (!isStructOfMatchingFixedVectors(OpTy))
    return false;
  if (CallInst *CI = dyn_cast<CallInst>(Op)) {
    // Only direct intrinsic calls that the target says can be scalarized
    // element-by-element are handled.
    Function *F = CI->getCalledFunction();
    if (!F)
      return false;
    Intrinsic::ID ID = F->getIntrinsicID();
    if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
      return false;
    // Note: Fall through means Operand is a `CallInst` and it is defined in
    // `isTriviallyScalarizable`.
  } else
    return false;
  Type *VecType = cast<FixedVectorType>(OpTy->getContainedType(0));
  std::optional<VectorSplit> VS = getVectorSplit(VecType);
  if (!VS)
    return false;
  IRBuilder<> Builder(&EVI);
  Scatterer Op0 = scatter(&EVI, Op, *VS);
  assert(!EVI.getIndices().empty() && "Make sure an index exists");
  // Note for our use case we only care about the top level index.
  unsigned Index = EVI.getIndices()[0];
  // Each scattered fragment of the call result is itself a struct; extract
  // member `Index` from every fragment to form the scalarized result.
  for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
    Value *ResElem = Builder.CreateExtractValue(
        Op0[OpIdx], Index, EVI.getName() + ".elem" + Twine(Index));
    Res.push_back(ResElem);
  }

  gather(&EVI, Res, *VS);
  return true;
}

// Scalarize an extractelement. A constant index resolves to one fragment; a
// variable index (when full scalarization is enabled) becomes a select chain.
bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  std::optional<VectorSplit> VS = getVectorSplit(EEI.getOperand(0)->getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&EEI);
  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0), *VS);
  Value *ExtIdx = EEI.getOperand(1);

  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
    // Constant index: pick the fragment directly; only a packed (vector)
    // fragment needs a residual extractelement at the sub-index.
    unsigned Idx = CI->getZExtValue();
    unsigned Fragment = Idx / VS->NumPacked;
    Value *Res = Op0[Fragment];
    bool IsPacked = VS->NumPacked > 1;
    if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
        !VS->RemainderTy->isVectorTy())
      IsPacked = false;
    if (IsPacked)
      Res = Builder.CreateExtractElement(Res, Idx % VS->NumPacked);
    replaceUses(&EEI, Res);
    return true;
  }

  // Never split a variable extractelement that isn't fully scalarized.
  if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
    return false;

  // Variable index: fold over all elements, keeping the one whose position
  // matches ExtIdx; the seed is poison for an out-of-range index.
  Value *Res = PoisonValue::get(VS->VecTy->getElementType());
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    Value *ShouldExtract =
        Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
                             ExtIdx->getName() + ".is." + Twine(I));
    Value *Elt = Op0[I];
    Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
                               EEI.getName() + ".upto" + Twine(I));
  }
  replaceUses(&EEI, Res);
  return true;
}

// Shuffles are handled only when result and operands are fully scalarized
// (one element per fragment); each result element is then just the selected
// source element, or poison for a negative (undef) mask value.
bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
  std::optional<VectorSplit> VS = getVectorSplit(SVI.getType());
  std::optional<VectorSplit> VSOp =
      getVectorSplit(SVI.getOperand(0)->getType());
  if (!VS || !VSOp || VS->NumPacked > 1 || VSOp->NumPacked > 1)
    return false;

  Scatterer Op0 = scatter(&SVI, SVI.getOperand(0), *VSOp);
  Scatterer Op1 = scatter(&SVI, SVI.getOperand(1), *VSOp);
  ValueVector Res;
  Res.resize(VS->NumFragments);

  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    int Selector = SVI.getMaskValue(I);
    if (Selector < 0)
      Res[I] = PoisonValue::get(VS->VecTy->getElementType());
    else if (unsigned(Selector) < Op0.size())
      Res[I] = Op0[Selector];
    else
      // Selectors past the first operand index into the second.
      Res[I] = Op1[Selector - Op0.size()];
  }
  gather(&SVI, Res, *VS);
  return true;
}

// Scalarize a PHI: create the per-fragment PHIs first, then scatter each
// incoming value once per predecessor and wire up the incoming edges.
bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
  std::optional<VectorSplit> VS = getVectorSplit(PHI.getType());
  if (!VS)
    return false;

  IRBuilder<> Builder(&PHI);
  ValueVector Res;
  Res.resize(VS->NumFragments);

  unsigned NumOps = PHI.getNumOperands();
  for (unsigned I = 0; I < VS->NumFragments; ++I) {
    Res[I] = Builder.CreatePHI(VS->getFragmentType(I), NumOps,
                               PHI.getName() + ".i" + Twine(I));
  }

  for (unsigned I = 0; I < NumOps; ++I) {
    Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I), *VS);
    BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
    for (unsigned J = 0; J < VS->NumFragments; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
  }
  gather(&PHI, Res, *VS);
  return true;
}

// Split a vector load into per-fragment loads, gated on the
// ScalarizeLoadStore option. Only simple (non-volatile, non-atomic) loads
// with a computable layout are split.
bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!LI.isSimple())
    return false;

  std::optional<VectorLayout> Layout = getVectorLayout(
      LI.getType(), LI.getAlign(), LI.getDataLayout());
  if (!Layout)
    return false;

  IRBuilder<> Builder(&LI);
  Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), Layout->VS);
  ValueVector Res;
  Res.resize(Layout->VS.NumFragments);

  for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
    // Each fragment load carries the alignment implied by its byte offset.
    Res[I] = Builder.CreateAlignedLoad(Layout->VS.getFragmentType(I), Ptr[I],
                                       Align(Layout->getFragmentAlign(I)),
                                       LI.getName() + ".i" + Twine(I));
  }
  gather(&LI, Res, Layout->VS);
  return true;
}

// Split a vector store into per-fragment stores; mirrors visitLoadInst.
bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!SI.isSimple())
    return false;

  Value *FullValue = SI.getValueOperand();
  std::optional<VectorLayout> Layout = getVectorLayout(
      FullValue->getType(), SI.getAlign(), SI.getDataLayout());
  if (!Layout)
    return false;

  IRBuilder<> Builder(&SI);
  Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), Layout->VS);
  Scatterer VVal = scatter(&SI, FullValue, Layout->VS);

  ValueVector Stores;
  Stores.resize(Layout->VS.NumFragments);
  for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
    Value *Val = VVal[I];
    Value *Ptr = VPtr[I];
    Stores[I] =
        Builder.CreateAlignedStore(Val, Ptr, Layout->getFragmentAlign(I));
  }
  // Stores produce no value to gather; just copy metadata/flags over.
  transferMetadataAndIRFlags(&SI, Stores);
  return true;
}

bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
}

// Freeze has no dedicated splitter class, so pass a lambda that replays the
// freeze on each fragment.
bool ScalarizerVisitor::visitFreezeInst(FreezeInst &FI) {
  return splitUnary(FI, [](IRBuilder<> &Builder, Value *Op, const Twine &Name) {
    return Builder.CreateFreeze(Op, Name);
  });
}

// Delete the instructions that we scalarized.
// If a full vector result
// is still needed, recreate it using InsertElements.
bool ScalarizerVisitor::finish() {
  // The presence of data in Gathered or Scattered indicates changes
  // made to the Function.
  if (Gathered.empty() && Scattered.empty() && !Scalarized)
    return false;
  for (const auto &GMI : Gathered) {
    Instruction *Op = GMI.first;
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      // The value is still needed, so recreate it using a series of
      // insertelements and/or shufflevectors.
      Value *Res;
      if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        IRBuilder<> Builder(Op);
        // For a PHI, build the replacement after the PHI cluster rather than
        // at the PHI itself.
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

        VectorSplit VS = *getVectorSplit(Ty);
        assert(VS.NumFragments == CV.size());

        Res = concatenate(Builder, CV, VS, Op->getName());

        Res->takeName(Op);
      } else if (auto *Ty = dyn_cast<StructType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        IRBuilder<> Builder(Op);
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());

        // Iterate over each element in the struct, collecting that member
        // from every scattered fragment (each fragment is itself a struct).
        unsigned NumOfStructElements = Ty->getNumElements();
        SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
        for (unsigned I = 0; I < NumOfStructElements; ++I) {
          for (auto *CVelem : CV) {
            Value *Elem = Builder.CreateExtractValue(
                CVelem, I, Op->getName() + ".elem" + Twine(I));
            ElemCV[I].push_back(Elem);
          }
        }
        // Rebuild the struct: concatenate each member's fragments into a
        // full vector and insert it into a poison struct value.
        Res = PoisonValue::get(Ty);
        for (unsigned I = 0; I < NumOfStructElements; ++I) {
          Type *ElemTy = Ty->getElementType(I);
          assert(isa<FixedVectorType>(ElemTy) &&
                 "Only Structs of all FixedVectorType supported");
          VectorSplit VS = *getVectorSplit(ElemTy);
          assert(VS.NumFragments == CV.size());

          Value *ConcatenatedVector =
              concatenate(Builder, ElemCV[I], VS, Op->getName());
          Res = Builder.CreateInsertValue(Res, ConcatenatedVector, I,
                                          Op->getName() + ".insert");
        }
      } else {
        // Non-vector, non-struct: the single gathered value stands in
        // directly; skip the no-op case where it is the instruction itself.
        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
        Res = CV[0];
        if (Op == Res)
          continue;
      }
      Op->replaceAllUsesWith(Res);
    }
    PotentiallyDeadInstrs.emplace_back(Op);
  }
  Gathered.clear();
  Scattered.clear();
  Scalarized = false;

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);

  return true;
}

// New-pass-manager entry point: run the visitor over F and report which
// analyses survive. The dominator tree is preserved when changes were made;
// an unchanged function preserves everything.
PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
  ScalarizerVisitor Impl(DT, TTI, Options);
  bool Changed = Impl.visit(F);
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return Changed ? PA : PreservedAnalyses::all();
}