//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include <numeric>
#include <queue>

#define DEBUG_TYPE "vector-combine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"

using namespace llvm;
using namespace llvm::PatternMatch;

STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static cl::opt<unsigned> MaxInstrsToScan(
    "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
    cl::desc("Max number of instructions to scan for vector combining."));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
                const DataLayout *DL, TTI::TargetCostKind CostKind,
                bool TryEarlyFoldsOnly)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), DL(DL),
        CostKind(CostKind), TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;
  AAResults &AA;
  AssumptionCache &AC;
  const DataLayout *DL;
  TTI::TargetCostKind CostKind;

  /// If true, only perform beneficial early IR transforms. Do not introduce new
  /// vector operations.
  bool TryEarlyFoldsOnly;

  InstructionWorklist Worklist;

  // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
  //       parameter. That should be updated to specific sub-classes because the
  //       run loop was changed to dispatch on opcode.
  bool vectorizeLoadInsert(Instruction &I);
  bool widenSubvectorLoad(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             const Instruction &I,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldInsExtFNeg(Instruction &I);
  bool foldInsExtVectorToShuffle(Instruction &I);
  bool foldBitcastShuffle(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool scalarizeVPIntrinsic(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
  bool foldSingleElementStore(Instruction &I);
  bool scalarizeLoadExtract(Instruction &I);
  bool foldConcatOfBoolMasks(Instruction &I);
  bool foldPermuteOfBinops(Instruction &I);
  bool foldShuffleOfBinops(Instruction &I);
  bool foldShuffleOfCastops(Instruction &I);
  bool foldShuffleOfShuffles(Instruction &I);
  bool foldShuffleOfIntrinsics(Instruction &I);
  bool foldShuffleToIdentity(Instruction &I);
  bool foldShuffleFromReductions(Instruction &I);
  bool foldCastFromReductions(Instruction &I);
  bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
  bool shrinkType(Instruction &I);

  void replaceValue(Value &Old, Value &New) {
    Old.replaceAllUsesWith(&New);
    if (auto *NewI = dyn_cast<Instruction>(&New)) {
      New.takeName(&Old);
      Worklist.pushUsersToWorkList(*NewI);
      Worklist.pushValue(NewI);
    }
    Worklist.pushValue(&Old);
  }

  void eraseInstruction(Instruction &I) {
    LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
    SmallVector<Value *> Ops(I.operands());
    Worklist.remove(&I);
    I.eraseFromParent();

    // Push remaining users and then the operand itself - allows further folds
    // that were hindered by OneUse limits.
    for (Value *Op : Ops)
      if (auto *OpI = dyn_cast<Instruction>(Op)) {
        Worklist.pushUsersToWorkList(*OpI);
        Worklist.pushValue(OpI);
      }
  }
};
} // namespace

/// Return the source operand of a potentially bitcasted value. If there is no
/// bitcast, return the input value itself.
static Value *peekThroughBitcasts(Value *V) {
  while (auto *BitCast = dyn_cast<BitCastInst>(V))
    V = BitCast->getOperand(0);
  return V;
}

static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
  // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
  // The widened load may load data from dirty regions or create data races
  // non-existent in the source.
  if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
      Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
      mustSuppressSpeculation(*Load))
    return false;

  // We are potentially transforming byte-sized (8-bit) memory accesses, so make
  // sure we have all of our type-based constraints in place for this target.
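  // Illustrative numbers (assumed, not target-specific): with a 128-bit
  // minimum vector register width, an i32 scalar load passes the checks below
  // (128 % 32 == 0 and 32 % 8 == 0), while an i1 or i20 scalar load is
  // rejected.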
  Type *ScalarTy = Load->getType()->getScalarType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
      ScalarSize % 8 != 0)
    return false;

  return true;
}

bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  // TODO: Handle non-zero insert index.
  Value *Scalar;
  if (!match(&I,
             m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt())))
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  auto *Load = dyn_cast<LoadInst>(X);
  if (!canWidenLoad(Load, TTI))
    return false;

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");

  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  unsigned OffsetEltIndex = 0;
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                   &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it is safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
    unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
    APInt Offset(OffsetBitWidth, 0);
    SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);

    // We want to shuffle the result down from a high element of a vector, so
    // the offset must be positive.
    if (Offset.isNegative())
      return false;

    // The offset must be a multiple of the scalar element size so we can
    // shuffle cleanly at element granularity.
    uint64_t ScalarSizeInBytes = ScalarSize / 8;
    if (Offset.urem(ScalarSizeInBytes) != 0)
      return false;

    // If we load MinVecNumElts, will our target element still be loaded?
    OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
    if (OffsetEltIndex >= MinVecNumElts)
      return false;

    if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                     &DT))
      return false;

    // Update alignment with offset value. Note that the offset could be negated
    // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
    // negation does not change the result of the alignment calculation.
    Alignment = commonAlignment(Alignment, Offset.getZExtValue());
  }

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
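  // For example (illustrative values): if the load is only marked align 4 but
  // the source pointer is known to be align 16, the widened vector load
  // created below can use align 16.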
  Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  OldCost +=
      TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                   /* Insert */ true, HasExtract, CostKind);

  // New pattern: load VecPtr
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
  // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask set everything but element 0 to undef to prevent poison from
  // propagating from the extra loaded memory. This will also optionally
  // shrink/grow the vector from the loaded size to the output size.
  // We assume this operation has no cost in codegen if there was no offset.
  // Note that we could use freeze to avoid poison problems, but then we might
  // still need a shuffle to change the vector size.
  auto *Ty = cast<FixedVectorType>(I.getType());
  unsigned OutputNumElts = Ty->getNumElements();
  SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
  assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
  Mask[0] = OffsetEltIndex;
  if (OffsetEltIndex)
    NewCost +=
        TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask, CostKind);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
  VecLd = Builder.CreateShuffleVector(VecLd, Mask);

  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// If we are loading a vector and then inserting it into a larger vector with
/// undefined elements, try to load the larger vector and eliminate the insert.
/// This removes a shuffle in IR and may allow combining of other loaded values.
bool VectorCombine::widenSubvectorLoad(Instruction &I) {
  // Match subvector insert of fixed vector.
  auto *Shuf = cast<ShuffleVectorInst>(&I);
  if (!Shuf->isIdentityWithPadding())
    return false;

  // Allow a non-canonical shuffle mask that is choosing elements from op1.
  unsigned NumOpElts =
      cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
  unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
    return M >= (int)(NumOpElts);
  });

  auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
  if (!canWidenLoad(Load, TTI))
    return false;

  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
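  // Illustrative IR for this fold (types assumed for the example):
  //   %v = load <2 x float>, ptr %p
  //   %r = shufflevector <2 x float> %v, <2 x float> poison,
  //                      <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
  // -->
  //   %r = load <4 x float>, ptr %p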
  auto *Ty = cast<FixedVectorType>(I.getType());
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT))
    return false;

  Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();

  // Original pattern: insert_subvector (load PtrOp)
  // This conservatively assumes that the cost of a subvector insert into an
  // undef value is 0. We could add that cost if the cost model accurately
  // reflects the real cost of that operation.
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);

  // New pattern: load PtrOp
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
  auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
  assert(Index0C && Index1C && "Expected constant extract indexes");

  unsigned Index0 = Index0C->getZExtValue();
  unsigned Index1 = Index1C->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  InstructionCost Cost0 =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
  InstructionCost Cost1 =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);

  // If both costs are invalid no shuffle is needed
  if (!Cost0.isValid() && !Cost1.isValid())
    return nullptr;

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative.
/// Otherwise, return false, and if one of the extracts should be transformed
/// to a shufflevector, set \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          const Instruction &I,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
  auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
  assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");

  unsigned Opcode = I.getOpcode();
  Value *Ext0Src = Ext0->getVectorOperand();
  Value *Ext1Src = Ext1->getVectorOperand();
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0Src->getType());
  InstructionCost ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = Ext0IndexC->getZExtValue();
  unsigned Ext1Index = Ext1IndexC->getZExtValue();

  InstructionCost Extract0Cost =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
  InstructionCost Extract1Cost =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
  unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
  InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  InstructionCost OldCost, NewCost;
  if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case.
    // Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // poison except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { poison, poison, 0, poison }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
      SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
                                   PoisonMaskElem);
      ShuffleMask[BestInsIndex] = BestExtIndex;
      NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                    VecTy, ShuffleMask, CostKind, 0, nullptr,
                                    {ConvertToShuffle});
    } else {
      NewCost +=
          TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy,
                             {}, CostKind, 0, nullptr, {ConvertToShuffle});
    }
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is poison except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, poison, poison, poison }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
  ShufMask[NewIndex] = OldIndex;
  return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // Shufflevectors can only be created for fixed-width vectors.
  Value *X = ExtElt->getVectorOperand();
  if (!isa<FixedVectorType>(X->getType()))
    return nullptr;

  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
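  // Illustrative example (assumed operands): translating
  //   %e = extractelement <4 x i32> %x, i64 2
  // to NewIndex == 0 produces
  //   %shift = shufflevector <4 x i32> %x, <4 x i32> poison,
  //            <4 x i32> <i32 2, i32 poison, i32 poison, i32 poison>
  //   %e     = extractelement <4 x i32> %shift, i64 0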
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  Worklist.push(Ext0);
  Worklist.push(Ext1);
  return true;
}

/// Try to replace an extract + scalar fneg + insert with a vector fneg +
/// shuffle.
bool VectorCombine::foldInsExtFNeg(Instruction &I) {
  // Match an insert (op (extract)) pattern.
  Value *DestVec;
  uint64_t Index;
  Instruction *FNeg;
  if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
                             m_ConstantInt(Index))))
    return false;

  // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
  Value *SrcVec;
  Instruction *Extract;
  if (!match(FNeg, m_FNeg(m_CombineAnd(
                       m_Instruction(Extract),
                       m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
    return false;

  auto *VecTy = cast<FixedVectorType>(I.getType());
  auto *ScalarTy = VecTy->getScalarType();
  auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
  if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType())
    return false;

  // Ignore bogus insert/extract index.
  unsigned NumElts = VecTy->getNumElements();
  if (Index >= NumElts)
    return false;

  // We are inserting the negated element into the same lane that we extracted
  // from. This is equivalent to a select-shuffle that chooses all but the
  // negated element from the destination vector.
  SmallVector<int> Mask(NumElts);
  std::iota(Mask.begin(), Mask.end(), 0);
  Mask[Index] = Index + NumElts;
  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) +
      TTI.getVectorInstrCost(I, VecTy, CostKind, Index);

  // If the extract has one use, it will be eliminated, so count it in the
  // original cost. If it has more than one use, ignore the cost because it will
  // be the same before/after.
  if (Extract->hasOneUse())
    OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);

  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) +
      TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, CostKind);

  bool NeedLenChg = SrcVecTy->getNumElements() != NumElts;
  // If the lengths of the two vectors are not equal,
  // we need to add a length-change vector. Add this cost.
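  // For example (illustrative sizes): negating lane 1 of a <2 x float> source
  // into a <4 x float> destination first widens the fneg result with
  // SrcMask = { poison, 1, poison, poison }, then applies the select mask
  // Mask = { 0, 5, 2, 3 } to pick the negated lane from the widened source.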
  SmallVector<int> SrcMask;
  if (NeedLenChg) {
    SrcMask.assign(NumElts, PoisonMaskElem);
    SrcMask[Index] = Index;
    NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                  SrcVecTy, SrcMask, CostKind);
  }

  if (NewCost > OldCost)
    return false;

  Value *NewShuf;
  // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index
  Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
  if (NeedLenChg) {
    // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
    Value *LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
    NewShuf = Builder.CreateShuffleVector(DestVec, LenChgShuf, Mask);
  } else {
    // shuffle DestVec, (fneg SrcVec), Mask
    NewShuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
  }

  replaceValue(I, *NewShuf);
  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
bool VectorCombine::foldBitcastShuffle(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(m_OneUse(
                     m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  //    scalable type is unknown; Second, we cannot reason if the narrowed
  //    shuffle mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
  if (!DestTy || !SrcTy)
    return false;

  unsigned DestEltSize = DestTy->getScalarSizeInBits();
  unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
  if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
    return false;

  bool IsUnary = isa<UndefValue>(V1);

  // For binary shuffles, only fold bitcast(shuffle(X,Y))
  // if it won't increase the number of bitcasts.
  if (!IsUnary) {
    auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType());
    auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType());
    if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
        !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
      return false;
  }

  SmallVector<int, 16> NewMask;
  if (DestEltSize <= SrcEltSize) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcEltSize / DestEltSize;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestEltSize / SrcEltSize;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }

  // Bitcast the shuffle src - keep its original width but using the destination
  // scalar type.
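  // e.g. (illustrative) for a <4 x i32> shuffle bitcast to <8 x i16>: the
  // 128-bit source width is kept, so NewShuffleTy is <8 x i16> and the
  // narrowed mask selects i16 halves of the original i32 lanes.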
  unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
  auto *NewShuffleTy =
      FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
  auto *OldShuffleTy =
      FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
  unsigned NumOps = IsUnary ? 1 : 2;

  // The new shuffle must not cost more than the old shuffle.
  TargetTransformInfo::ShuffleKind SK =
      IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
              : TargetTransformInfo::SK_PermuteTwoSrc;

  InstructionCost NewCost =
      TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CostKind) +
      (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
                                     TargetTransformInfo::CastContextHint::None,
                                     CostKind));
  InstructionCost OldCost =
      TTI.getShuffleCost(SK, SrcTy, Mask, CostKind) +
      TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
                           TargetTransformInfo::CastContextHint::None,
                           CostKind);

  LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: "
                    << OldCost << " vs NewCost: " << NewCost << "\n");

  if (NewCost > OldCost || !NewCost.isValid())
    return false;

  // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
  ++NumShufOfBitcast;
  Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
  Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// VP Intrinsics whose vector operands are both splat values may be simplified
/// into the scalar version of the operation and the result splatted. This
/// can lead to scalarization down the line.
bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
  if (!isa<VPIntrinsic>(I))
    return false;
  VPIntrinsic &VPI = cast<VPIntrinsic>(I);
  Value *Op0 = VPI.getArgOperand(0);
  Value *Op1 = VPI.getArgOperand(1);

  if (!isSplatValue(Op0) || !isSplatValue(Op1))
    return false;

  // Check getSplatValue early in this function, to avoid doing unnecessary
  // work.
  Value *ScalarOp0 = getSplatValue(Op0);
  Value *ScalarOp1 = getSplatValue(Op1);
  if (!ScalarOp0 || !ScalarOp1)
    return false;

  // For the binary VP intrinsics supported here, the result on disabled lanes
  // is a poison value. For now, only do this simplification if all lanes
  // are active.
  // TODO: Relax the condition that all lanes are active by using insertelement
  //       on inactive lanes.
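  // Sketch of the overall fold (pseudo-IR, operand names assumed):
  //   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> splat(%a),
  //                                          <4 x i32> splat(%b),
  //                                          <4 x i1> splat(i1 true), i32 %evl)
  // -->
  //   %s = add i32 %a, %b
  //   %r = <4 x i32> splat(%s)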
  auto IsAllTrueMask = [](Value *MaskVal) {
    if (Value *SplattedVal = getSplatValue(MaskVal))
      if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
        return ConstValue->isAllOnesValue();
    return false;
  };
  if (!IsAllTrueMask(VPI.getArgOperand(2)))
    return false;

  // Check to make sure we support scalarization of the intrinsic
  Intrinsic::ID IntrID = VPI.getIntrinsicID();
  if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
    return false;

  // Calculate cost of splatting both operands into vectors and the vector
  // intrinsic
  VectorType *VecTy = cast<VectorType>(VPI.getType());
  SmallVector<int> Mask;
  if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
    Mask.resize(FVTy->getNumElements(), 0);
  InstructionCost SplatCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
      TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, Mask,
                         CostKind);

  // Calculate the cost of the VP Intrinsic
  SmallVector<Type *, 4> Args;
  for (Value *V : VPI.args())
    Args.push_back(V->getType());
  IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
  InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
  InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

  // Determine scalar opcode
  std::optional<unsigned> FunctionalOpcode =
      VPI.getFunctionalOpcode();
  std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
  if (!FunctionalOpcode) {
    ScalarIntrID = VPI.getFunctionalIntrinsicID();
    if (!ScalarIntrID)
      return false;
  }

  // Calculate cost of scalarizing
  InstructionCost ScalarOpCost = 0;
  if (ScalarIntrID) {
    IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
    ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
                                              VecTy->getScalarType(), CostKind);
  }

  // The existing splats may be kept around if other instructions use them.
  InstructionCost CostToKeepSplats =
      (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
  InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;

  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
                    << "\n");
  LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
                    << ", Cost of scalarizing:" << NewCost << "\n");

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Scalarize the intrinsic
  ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
  Value *EVL = VPI.getArgOperand(3);

  // If the VP op might introduce UB or poison, we can scalarize it provided
  // that we know the EVL > 0: If the EVL is zero, then the original VP op
  // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
  // scalarizing it.
  bool SafeToSpeculate;
  if (ScalarIntrID)
    SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID)
                          .hasFnAttr(Attribute::AttrKind::Speculatable);
  else
    SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
        *FunctionalOpcode, &VPI, nullptr, &AC, &DT);
  if (!SafeToSpeculate &&
      !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI)))
    return false;

  Value *ScalarVal =
      ScalarIntrID
          ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
                                    {ScalarOp0, ScalarOp1})
          : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
                                ScalarOp0, ScalarOp1);

  replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
  return true;
}

/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  Value *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
    return false;

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
  if (IsCmp)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;

  // Match against one or both scalar values being inserted into constant
  // vectors:
  // vec_op VecC0, (inselt VecC1, V1, Index)
  // vec_op (inselt VecC0, V0, Index), VecC1
  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0 = nullptr, *VecC1 = nullptr;
  Value *V0 = nullptr, *V1 = nullptr;
  uint64_t Index0 = 0, Index1 = 0;
  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
                               m_ConstantInt(Index0))) &&
      !match(Ins0, m_Constant(VecC0)))
    return false;
  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
                               m_ConstantInt(Index1))) &&
      !match(Ins1, m_Constant(VecC1)))
    return false;

  bool IsConst0 = !V0;
  bool IsConst1 = !V1;
  if (IsConst0 && IsConst1)
    return false;
  if (!IsConst0 && !IsConst1 && Index0 != Index1)
    return false;

  auto *VecTy0 = cast<VectorType>(Ins0->getType());
  auto *VecTy1 = cast<VectorType>(Ins1->getType());
  if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
      VecTy1->getElementCount().getKnownMinValue() <= Index1)
    return false;

  // Bail for single insertion if it is a load.
  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
  auto *I0 = dyn_cast_or_null<Instruction>(V0);
  auto *I1 = dyn_cast_or_null<Instruction>(V1);
  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
      (IsConst1 && I0 && I0->mayReadFromMemory()))
    return false;

  uint64_t Index = IsConst0 ? Index1 : Index0;
  Type *ScalarTy = IsConst0 ?
      V1->getType() : V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() &&
         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  unsigned Opcode = I.getOpcode();
  InstructionCost ScalarOpCost, VectorOpCost;
  if (IsCmp) {
    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
  }

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  InstructionCost InsertCost = TTI.getVectorInstrCost(
      Instruction::InsertElement, VecTy, CostKind, Index);
  InstructionCost OldCost =
      (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost;
  InstructionCost NewCost = ScalarOpCost + InsertCost +
                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (IsCmp)
    ++NumScalarCmp;
  else
    ++NumScalarBO;

  // For constant cases, extract the scalar element, this should constant fold.
  if (IsConst0)
    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
  if (IsConst1)
    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));

  Value *Scalar =
      IsCmp ? Builder.CreateCmp(Pred, V0, V1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Value *NewVecC =
      IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  replaceValue(I, *Insert);
  return true;
}

/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  auto *BI = dyn_cast<BinaryOperator>(&I);

  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!BI || !I.getType()->isIntegerTy(1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
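  // Illustrative pattern being matched (names and types assumed):
  //   %e0 = extractelement <4 x i32> %x, i32 0
  //   %e1 = extractelement <4 x i32> %x, i32 2
  //   %c0 = icmp sgt i32 %e0, 10
  //   %c1 = icmp sgt i32 %e1, 20
  //   %r  = and i1 %c0, %c1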
  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpPredicate P0, P1;
  // FIXME: Use CmpPredicate::getMatching here.
  if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
      !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) ||
      P0 != static_cast<CmpInst::Predicate>(P1))
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) ||
      !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, CostKind);
  if (!ConvertToShuf)
    return false;
  assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
         "Unknown ExtractElementInst");

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = P0;
  unsigned CmpOpcode =
      CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  InstructionCost Ext0Cost =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
  InstructionCost Ext1Cost =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
  InstructionCost CmpCost = TTI.getCmpSelInstrCost(
      CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred,
      CostKind);

  InstructionCost OldCost =
      Ext0Cost + Ext1Cost + CmpCost * 2 +
      TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
  auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
  InstructionCost NewCost = TTI.getCmpSelInstrCost(
      CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred,
      CostKind);
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
  ShufMask[CheapIndex] = ExpensiveIndex;
  NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
                                ShufMask, CostKind);
  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
  NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
  NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
  NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Create a vector constant from the 2 scalar constants.
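  // Continuing the illustrative pattern above (Index0 == 0, Index1 == 2),
  // CmpC becomes <i32 10, i32 poison, i32 20, i32 poison>.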
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   PoisonValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
  Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
  Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

// Check if memory loc modified between two instrs in the same BB
static bool isMemModifiedBetween(BasicBlock::iterator Begin,
                                 BasicBlock::iterator End,
                                 const MemoryLocation &Loc, AAResults &AA) {
  unsigned NumScanned = 0;
  return std::any_of(Begin, End, [&](const Instruction &Instr) {
    return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
           ++NumScanned > MaxInstrsToScan;
  });
}

namespace {
/// Helper class to indicate whether a vector index can be safely scalarized and
/// if a freeze needs to be inserted.
class ScalarizationResult {
  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };

  StatusTy Status;
  Value *ToFreeze;

  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
      : Status(Status), ToFreeze(ToFreeze) {}

public:
  ScalarizationResult(const ScalarizationResult &Other) = default;
  ~ScalarizationResult() {
    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
  }

  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
  static ScalarizationResult safe() { return {StatusTy::Safe}; }
  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
    return {StatusTy::SafeWithFreeze, ToFreeze};
  }

  /// Returns true if the index can be scalarized without requiring a freeze.
  bool isSafe() const { return Status == StatusTy::Safe; }
  /// Returns true if the index cannot be scalarized.
  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  /// Returns true if the index can be scalarized, but requires inserting a
  /// freeze.
  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }

  /// Reset the state to Unsafe and clear ToFreeze if set.
  void discard() {
    ToFreeze = nullptr;
    Status = StatusTy::Unsafe;
  }

  /// Freeze the ToFreeze and update the use in \p User to use it.
  void freeze(IRBuilder<> &Builder, Instruction &UserI) {
    assert(isSafeWithFreeze() &&
           "should only be used when freezing is required");
    assert(is_contained(ToFreeze->users(), &UserI) &&
           "UserI must be a user of ToFreeze");
    IRBuilder<>::InsertPointGuard Guard(Builder);
    Builder.SetInsertPoint(cast<Instruction>(&UserI));
    Value *Frozen =
        Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
    for (Use &U : make_early_inc_range((UserI.operands())))
      if (U.get() == ToFreeze)
        U.set(Frozen);

    ToFreeze = nullptr;
  }
};
} // namespace

/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
/// Idx. \p Idx must access a valid vector element.
static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
                                              Instruction *CtxI,
                                              AssumptionCache &AC,
                                              const DominatorTree &DT) {
  // We do checks for both fixed vector types and scalable vector types.
  // This is the number of elements of fixed vector types,
  // or the minimum number of elements of scalable vector types.
  uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();

  if (auto *C = dyn_cast<ConstantInt>(Idx)) {
    if (C->getValue().ult(NumElements))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
  APInt Zero(IntWidth, 0);
  APInt MaxElts(IntWidth, NumElements);
  ConstantRange ValidIndices(Zero, MaxElts);
  ConstantRange IdxRange(IntWidth, true);

  if (isGuaranteedNotToBePoison(Idx, &AC)) {
    if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
                                                   true, &AC, CtxI, &DT)))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  // If the index may be poison, check if we can insert a freeze before the
  // range of the index is restricted.
  Value *IdxBase;
  ConstantInt *CI;
  if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
    IdxRange = IdxRange.binaryAnd(CI->getValue());
  } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
    IdxRange = IdxRange.urem(CI->getValue());
  }

  if (ValidIndices.contains(IdxRange))
    return ScalarizationResult::safeWithFreeze(IdxBase);
  return ScalarizationResult::unsafe();
}

/// The memory operation on a vector of \p ScalarType had alignment of
/// \p VectorAlignment. Compute the maximal, but conservatively correct,
/// alignment that will be valid for the memory operation on a single scalar
/// element of the same type with index \p Idx.
static Align computeAlignmentAfterScalarization(Align VectorAlignment,
                                                Type *ScalarType, Value *Idx,
                                                const DataLayout &DL) {
  if (auto *C = dyn_cast<ConstantInt>(Idx))
    return commonAlignment(VectorAlignment,
                           C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
  return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
}

// Combine patterns like:
//   %0 = load <4 x i32>, <4 x i32>* %a
//   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
//   store <4 x i32> %1, <4 x i32>* %a
// to:
//   %0 = bitcast <4 x i32>* %a to i32*
//   %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
//   store i32 %b, i32* %1
bool VectorCombine::foldSingleElementStore(Instruction &I) {
  auto *SI = cast<StoreInst>(&I);
  if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
    return false;

  // TODO: Combine more complicated patterns (multiple insert) by referencing
  //       TargetTransformInfo.
  Instruction *Source;
  Value *NewElement;
  Value *Idx;
  if (!match(SI->getValueOperand(),
             m_InsertElt(m_Instruction(Source), m_Value(NewElement),
                         m_Value(Idx))))
    return false;

  if (auto *Load = dyn_cast<LoadInst>(Source)) {
    auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
    Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
    // Don't optimize for atomic/volatile load or store.
    // Ensure memory is not modified between, vector type matches store size,
    // and index is inbounds.
    if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
        !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
        SrcAddr != SI->getPointerOperand()->stripPointerCasts())
      return false;

    auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
    if (ScalarizableIdx.isUnsafe() ||
        isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                             MemoryLocation::get(SI), AA))
      return false;

    if (ScalarizableIdx.isSafeWithFreeze())
      ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
    Value *GEP = Builder.CreateInBoundsGEP(
        SI->getValueOperand()->getType(), SI->getPointerOperand(),
        {ConstantInt::get(Idx->getType(), 0), Idx});
    StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
    NSI->copyMetadata(*SI);
    Align ScalarOpAlignment = computeAlignmentAfterScalarization(
        std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
        *DL);
    NSI->setAlignment(ScalarOpAlignment);
    replaceValue(I, *NSI);
    eraseInstruction(I);
    return true;
  }

  return false;
}

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
  Value *Ptr;
  if (!match(&I, m_Load(m_Value(Ptr))))
    return false;

  auto *VecTy = cast<VectorType>(I.getType());
  auto *LI = cast<LoadInst>(&I);
  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
    return false;

  InstructionCost OriginalCost =
      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
                          LI->getPointerAddressSpace(), CostKind);
  InstructionCost ScalarizedCost = 0;

  Instruction *LastCheckedInst = LI;
  unsigned NumInstChecked = 0;
  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
  auto FailureGuard = make_scope_exit([&]() {
    // If the transform is aborted, discard the ScalarizationResults.
    for (auto &Pair : NeedFreeze)
      Pair.second.discard();
  });

  // Check if all users of the load are extracts with no memory modifications
  // between the load and the extract. Compute the cost of both the original
  // code and the scalarized version.
  for (User *U : LI->users()) {
    auto *UI = dyn_cast<ExtractElementInst>(U);
    if (!UI || UI->getParent() != LI->getParent())
      return false;

    // Check if any instruction between the load and the extract may modify
    // memory.
    if (LastCheckedInst->comesBefore(UI)) {
      for (Instruction &I :
           make_range(std::next(LI->getIterator()), UI->getIterator())) {
        // Bail out if we reached the check limit or the instruction may write
        // to memory.
        if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
          return false;
        NumInstChecked++;
      }
      LastCheckedInst = UI;
    }

    auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
    if (ScalarIdx.isUnsafe())
      return false;
    if (ScalarIdx.isSafeWithFreeze()) {
      NeedFreeze.try_emplace(UI, ScalarIdx);
      ScalarIdx.discard();
    }

    auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
    OriginalCost +=
        TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
                               Index ?
Index->getZExtValue() : -1); 1423 ScalarizedCost += 1424 TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(), 1425 Align(1), LI->getPointerAddressSpace(), CostKind); 1426 ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType()); 1427 } 1428 1429 if (ScalarizedCost >= OriginalCost) 1430 return false; 1431 1432 // Replace extracts with narrow scalar loads. 1433 for (User *U : LI->users()) { 1434 auto *EI = cast<ExtractElementInst>(U); 1435 Value *Idx = EI->getOperand(1); 1436 1437 // Insert 'freeze' for poison indexes. 1438 auto It = NeedFreeze.find(EI); 1439 if (It != NeedFreeze.end()) 1440 It->second.freeze(Builder, *cast<Instruction>(Idx)); 1441 1442 Builder.SetInsertPoint(EI); 1443 Value *GEP = 1444 Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx}); 1445 auto *NewLoad = cast<LoadInst>(Builder.CreateLoad( 1446 VecTy->getElementType(), GEP, EI->getName() + ".scalar")); 1447 1448 Align ScalarOpAlignment = computeAlignmentAfterScalarization( 1449 LI->getAlign(), VecTy->getElementType(), Idx, *DL); 1450 NewLoad->setAlignment(ScalarOpAlignment); 1451 1452 replaceValue(*EI, *NewLoad); 1453 } 1454 1455 FailureGuard.release(); 1456 return true; 1457 } 1458 1459 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" 1460 /// to "(bitcast (concat X, Y))" 1461 /// where X/Y are bitcasted from i1 mask vectors. 1462 bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) { 1463 Type *Ty = I.getType(); 1464 if (!Ty->isIntegerTy()) 1465 return false; 1466 1467 // TODO: Add big endian test coverage 1468 if (DL->isBigEndian()) 1469 return false; 1470 1471 // Restrict to disjoint cases so the mask vectors aren't overlapping. 1472 Instruction *X, *Y; 1473 if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y)))) 1474 return false; 1475 1476 // Allow both sources to contain shl, to handle more generic pattern: 1477 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))" 1478 Value *SrcX; 1479 uint64_t ShAmtX = 0; 1480 if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) && 1481 !match(X, m_OneUse( 1482 m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))), 1483 m_ConstantInt(ShAmtX))))) 1484 return false; 1485 1486 Value *SrcY; 1487 uint64_t ShAmtY = 0; 1488 if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) && 1489 !match(Y, m_OneUse( 1490 m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))), 1491 m_ConstantInt(ShAmtY))))) 1492 return false; 1493 1494 // Canonicalize larger shift to the RHS. 1495 if (ShAmtX > ShAmtY) { 1496 std::swap(X, Y); 1497 std::swap(SrcX, SrcY); 1498 std::swap(ShAmtX, ShAmtY); 1499 } 1500 1501 // Ensure both sources are matching vXi1 bool mask types, and that the shift 1502 // difference is the mask width so they can be easily concatenated together. 
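// Illustrative example (hypothetical values, not taken from a test): with
// SrcX and SrcY both <8 x i1>, ShAmtX == 0 and ShAmtY == 8, the two
// zero-extended masks occupy disjoint halves of the scalar, so the result is
// equivalent to bitcasting the concatenated <16 x i1> mask to i16.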
1503 uint64_t ShAmtDiff = ShAmtY - ShAmtX; 1504 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0); 1505 unsigned BitWidth = Ty->getPrimitiveSizeInBits(); 1506 auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType()); 1507 if (!MaskTy || SrcX->getType() != SrcY->getType() || 1508 !MaskTy->getElementType()->isIntegerTy(1) || 1509 MaskTy->getNumElements() != ShAmtDiff || 1510 MaskTy->getNumElements() > (BitWidth / 2)) 1511 return false; 1512 1513 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy); 1514 auto *ConcatIntTy = 1515 Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements()); 1516 auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff); 1517 1518 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements()); 1519 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 1520 1521 // TODO: Is it worth supporting multi use cases? 1522 InstructionCost OldCost = 0; 1523 OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind); 1524 OldCost += 1525 NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind); 1526 OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy, 1527 TTI::CastContextHint::None, CostKind); 1528 OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy, 1529 TTI::CastContextHint::None, CostKind); 1530 1531 InstructionCost NewCost = 0; 1532 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy, 1533 ConcatMask, CostKind); 1534 NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy, 1535 TTI::CastContextHint::None, CostKind); 1536 if (Ty != ConcatIntTy) 1537 NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy, 1538 TTI::CastContextHint::None, CostKind); 1539 if (ShAmtX > 0) 1540 NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind); 1541 1542 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I 1543 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1544 << "\n"); 1545 1546 if (NewCost > OldCost) 1547 return false; 1548 1549 // Build bool mask concatenation, bitcast back to scalar integer, and perform 1550 // any residual zero-extension or shifting. 1551 Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask); 1552 Worklist.pushValue(Concat); 1553 1554 Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy); 1555 1556 if (Ty != ConcatIntTy) { 1557 Worklist.pushValue(Result); 1558 Result = Builder.CreateZExt(Result, Ty); 1559 } 1560 1561 if (ShAmtX > 0) { 1562 Worklist.pushValue(Result); 1563 Result = Builder.CreateShl(Result, ShAmtX); 1564 } 1565 1566 replaceValue(I, *Result); 1567 return true; 1568 } 1569 1570 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef" 1571 /// --> "binop (shuffle), (shuffle)". 1572 bool VectorCombine::foldPermuteOfBinops(Instruction &I) { 1573 BinaryOperator *BinOp; 1574 ArrayRef<int> OuterMask; 1575 if (!match(&I, 1576 m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask)))) 1577 return false; 1578 1579 // Don't introduce poison into div/rem. 
1580 if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem)) 1581 return false; 1582 1583 Value *Op00, *Op01; 1584 ArrayRef<int> Mask0; 1585 if (!match(BinOp->getOperand(0), 1586 m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0))))) 1587 return false; 1588 1589 Value *Op10, *Op11; 1590 ArrayRef<int> Mask1; 1591 if (!match(BinOp->getOperand(1), 1592 m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1))))) 1593 return false; 1594 1595 Instruction::BinaryOps Opcode = BinOp->getOpcode(); 1596 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1597 auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType()); 1598 auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType()); 1599 auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType()); 1600 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty) 1601 return false; 1602 1603 unsigned NumSrcElts = BinOpTy->getNumElements(); 1604 1605 // Don't accept shuffles that reference the second operand in 1606 // div/rem or if its an undef arg. 1607 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) && 1608 any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; })) 1609 return false; 1610 1611 // Merge outer / inner shuffles. 1612 SmallVector<int> NewMask0, NewMask1; 1613 for (int M : OuterMask) { 1614 if (M < 0 || M >= (int)NumSrcElts) { 1615 NewMask0.push_back(PoisonMaskElem); 1616 NewMask1.push_back(PoisonMaskElem); 1617 } else { 1618 NewMask0.push_back(Mask0[M]); 1619 NewMask1.push_back(Mask1[M]); 1620 } 1621 } 1622 1623 // Try to merge shuffles across the binop if the new shuffles are not costly. 1624 InstructionCost OldCost = 1625 TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) + 1626 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy, 1627 OuterMask, CostKind, 0, nullptr, {BinOp}, &I) + 1628 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0, 1629 CostKind, 0, nullptr, {Op00, Op01}, 1630 cast<Instruction>(BinOp->getOperand(0))) + 1631 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1, 1632 CostKind, 0, nullptr, {Op10, Op11}, 1633 cast<Instruction>(BinOp->getOperand(1))); 1634 1635 InstructionCost NewCost = 1636 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0, 1637 CostKind, 0, nullptr, {Op00, Op01}) + 1638 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1, 1639 CostKind, 0, nullptr, {Op10, Op11}) + 1640 TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind); 1641 1642 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I 1643 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1644 << "\n"); 1645 1646 // If costs are equal, still fold as we reduce instruction count. 1647 if (NewCost > OldCost) 1648 return false; 1649 1650 Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0); 1651 Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1); 1652 Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1); 1653 1654 // Intersect flags from the old binops. 1655 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) 1656 NewInst->copyIRFlags(BinOp); 1657 1658 Worklist.pushValue(Shuf0); 1659 Worklist.pushValue(Shuf1); 1660 replaceValue(I, *NewBO); 1661 return true; 1662 } 1663 1664 /// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)". 1665 /// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)". 
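/// For illustration (hypothetical IR), a mask that keeps the low halves of
/// two v4i32 adds:
///   shuffle (add %x, %y), (add %z, %w), <0, 1, 4, 5>
/// can be rewritten as:
///   add (shuffle %x, %z, <0, 1, 4, 5>), (shuffle %y, %w, <0, 1, 4, 5>)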
1666 bool VectorCombine::foldShuffleOfBinops(Instruction &I) { 1667 ArrayRef<int> OldMask; 1668 Instruction *LHS, *RHS; 1669 if (!match(&I, m_Shuffle(m_OneUse(m_Instruction(LHS)), 1670 m_OneUse(m_Instruction(RHS)), m_Mask(OldMask)))) 1671 return false; 1672 1673 // TODO: Add support for addlike etc. 1674 if (LHS->getOpcode() != RHS->getOpcode()) 1675 return false; 1676 1677 Value *X, *Y, *Z, *W; 1678 bool IsCommutative = false; 1679 CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE; 1680 CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE; 1681 if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) && 1682 match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) { 1683 auto *BO = cast<BinaryOperator>(LHS); 1684 // Don't introduce poison into div/rem. 1685 if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem()) 1686 return false; 1687 IsCommutative = BinaryOperator::isCommutative(BO->getOpcode()); 1688 } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) && 1689 match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) && 1690 (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) { 1691 IsCommutative = cast<CmpInst>(LHS)->isCommutative(); 1692 } else 1693 return false; 1694 1695 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1696 auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType()); 1697 auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType()); 1698 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType()) 1699 return false; 1700 1701 unsigned NumSrcElts = BinOpTy->getNumElements(); 1702 1703 // If we have something like "add X, Y" and "add Z, X", swap ops to match. 1704 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z)) 1705 std::swap(X, Y); 1706 1707 auto ConvertToUnary = [NumSrcElts](int &M) { 1708 if (M >= (int)NumSrcElts) 1709 M -= NumSrcElts; 1710 }; 1711 1712 SmallVector<int> NewMask0(OldMask); 1713 TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc; 1714 if (X == Z) { 1715 llvm::for_each(NewMask0, ConvertToUnary); 1716 SK0 = TargetTransformInfo::SK_PermuteSingleSrc; 1717 Z = PoisonValue::get(BinOpTy); 1718 } 1719 1720 SmallVector<int> NewMask1(OldMask); 1721 TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc; 1722 if (Y == W) { 1723 llvm::for_each(NewMask1, ConvertToUnary); 1724 SK1 = TargetTransformInfo::SK_PermuteSingleSrc; 1725 W = PoisonValue::get(BinOpTy); 1726 } 1727 1728 // Try to replace a binop with a shuffle if the shuffle is not costly. 
1729 InstructionCost OldCost = 1730 TTI.getInstructionCost(LHS, CostKind) + 1731 TTI.getInstructionCost(RHS, CostKind) + 1732 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinResTy, 1733 OldMask, CostKind, 0, nullptr, {LHS, RHS}, &I); 1734 1735 InstructionCost NewCost = 1736 TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) + 1737 TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W}); 1738 1739 if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) { 1740 NewCost += 1741 TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind); 1742 } else { 1743 auto *ShuffleCmpTy = 1744 FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy); 1745 NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, 1746 ShuffleDstTy, PredLHS, CostKind); 1747 } 1748 1749 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I 1750 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1751 << "\n"); 1752 1753 // If either shuffle will constant fold away, then fold for the same cost as 1754 // we will reduce the instruction count. 1755 bool ReducedInstCount = (isa<Constant>(X) && isa<Constant>(Z)) || 1756 (isa<Constant>(Y) && isa<Constant>(W)); 1757 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost)) 1758 return false; 1759 1760 Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0); 1761 Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1); 1762 Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE 1763 ? Builder.CreateBinOp( 1764 cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1) 1765 : Builder.CreateCmp(PredLHS, Shuf0, Shuf1); 1766 1767 // Intersect flags from the old binops. 1768 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) { 1769 NewInst->copyIRFlags(LHS); 1770 NewInst->andIRFlags(RHS); 1771 } 1772 1773 Worklist.pushValue(Shuf0); 1774 Worklist.pushValue(Shuf1); 1775 replaceValue(I, *NewBO); 1776 return true; 1777 } 1778 1779 /// Try to convert "shuffle (castop), (castop)" with a shared castop operand 1780 /// into "castop (shuffle)". 1781 bool VectorCombine::foldShuffleOfCastops(Instruction &I) { 1782 Value *V0, *V1; 1783 ArrayRef<int> OldMask; 1784 if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask)))) 1785 return false; 1786 1787 auto *C0 = dyn_cast<CastInst>(V0); 1788 auto *C1 = dyn_cast<CastInst>(V1); 1789 if (!C0 || !C1) 1790 return false; 1791 1792 Instruction::CastOps Opcode = C0->getOpcode(); 1793 if (C0->getSrcTy() != C1->getSrcTy()) 1794 return false; 1795 1796 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds. 1797 if (Opcode != C1->getOpcode()) { 1798 if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value()))) 1799 Opcode = Instruction::SExt; 1800 else 1801 return false; 1802 } 1803 1804 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1805 auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy()); 1806 auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy()); 1807 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy) 1808 return false; 1809 1810 unsigned NumSrcElts = CastSrcTy->getNumElements(); 1811 unsigned NumDstElts = CastDstTy->getNumElements(); 1812 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) && 1813 "Only bitcasts expected to alter src/dst element counts"); 1814 1815 // Check for bitcasting of unscalable vector types. 1816 // e.g. 
<32 x i40> -> <40 x i32> 1817 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 && 1818 (NumDstElts % NumSrcElts) != 0) 1819 return false; 1820 1821 SmallVector<int, 16> NewMask; 1822 if (NumSrcElts >= NumDstElts) { 1823 // The bitcast is from wide to narrow/equal elements. The shuffle mask can 1824 // always be expanded to the equivalent form choosing narrower elements. 1825 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask"); 1826 unsigned ScaleFactor = NumSrcElts / NumDstElts; 1827 narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask); 1828 } else { 1829 // The bitcast is from narrow elements to wide elements. The shuffle mask 1830 // must choose consecutive elements to allow casting first. 1831 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask"); 1832 unsigned ScaleFactor = NumDstElts / NumSrcElts; 1833 if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask)) 1834 return false; 1835 } 1836 1837 auto *NewShuffleDstTy = 1838 FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size()); 1839 1840 // Try to replace a castop with a shuffle if the shuffle is not costly. 1841 InstructionCost CostC0 = 1842 TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy, 1843 TTI::CastContextHint::None, CostKind); 1844 InstructionCost CostC1 = 1845 TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy, 1846 TTI::CastContextHint::None, CostKind); 1847 InstructionCost OldCost = CostC0 + CostC1; 1848 OldCost += 1849 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy, 1850 OldMask, CostKind, 0, nullptr, {}, &I); 1851 1852 InstructionCost NewCost = TTI.getShuffleCost( 1853 TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind); 1854 NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy, 1855 TTI::CastContextHint::None, CostKind); 1856 if (!C0->hasOneUse()) 1857 NewCost += CostC0; 1858 if (!C1->hasOneUse()) 1859 NewCost += CostC1; 1860 1861 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I 1862 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1863 << "\n"); 1864 if (NewCost > OldCost) 1865 return false; 1866 1867 Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0), 1868 C1->getOperand(0), NewMask); 1869 Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy); 1870 1871 // Intersect flags from the old casts. 1872 if (auto *NewInst = dyn_cast<Instruction>(Cast)) { 1873 NewInst->copyIRFlags(C0); 1874 NewInst->andIRFlags(C1); 1875 } 1876 1877 Worklist.pushValue(Shuf); 1878 replaceValue(I, *Cast); 1879 return true; 1880 } 1881 1882 /// Try to convert any of: 1883 /// "shuffle (shuffle x, y), (shuffle y, x)" 1884 /// "shuffle (shuffle x, undef), (shuffle y, undef)" 1885 /// "shuffle (shuffle x, undef), y" 1886 /// "shuffle x, (shuffle y, undef)" 1887 /// into "shuffle x, y". 1888 bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { 1889 ArrayRef<int> OuterMask; 1890 Value *OuterV0, *OuterV1; 1891 if (!match(&I, 1892 m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask)))) 1893 return false; 1894 1895 ArrayRef<int> InnerMask0, InnerMask1; 1896 Value *X0, *X1, *Y0, *Y1; 1897 bool Match0 = 1898 match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0))); 1899 bool Match1 = 1900 match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1))); 1901 if (!Match0 && !Match1) 1902 return false; 1903 1904 X0 = Match0 ? X0 : OuterV0; 1905 Y0 = Match0 ? Y0 : OuterV0; 1906 X1 = Match1 ? X1 : OuterV1; 1907 Y1 = Match1 ? 
Y1 : OuterV1;
  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType());
  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType());
  if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
      X0->getType() != X1->getType())
    return false;

  unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
  unsigned NumImmElts = ShuffleImmTy->getNumElements();

  // Attempt to merge shuffles, matching up to 2 source operands.
  // Replace an index into a poison arg with PoisonMaskElem.
  // Bail if either inner mask references an undef arg.
  SmallVector<int, 16> NewMask(OuterMask);
  Value *NewX = nullptr, *NewY = nullptr;
  for (int &M : NewMask) {
    Value *Src = nullptr;
    if (0 <= M && M < (int)NumImmElts) {
      Src = OuterV0;
      if (Match0) {
        M = InnerMask0[M];
        Src = M >= (int)NumSrcElts ? Y0 : X0;
        M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
      }
    } else if (M >= (int)NumImmElts) {
      Src = OuterV1;
      M -= NumImmElts;
      if (Match1) {
        M = InnerMask1[M];
        Src = M >= (int)NumSrcElts ? Y1 : X1;
        M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
      }
    }
    if (Src && M != PoisonMaskElem) {
      assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
      if (isa<UndefValue>(Src)) {
        // We've referenced an undef element - if it's poison, update the
        // shuffle mask, else bail.
        if (!isa<PoisonValue>(Src))
          return false;
        M = PoisonMaskElem;
        continue;
      }
      if (!NewX || NewX == Src) {
        NewX = Src;
        continue;
      }
      if (!NewY || NewY == Src) {
        M += NumSrcElts;
        NewY = Src;
        continue;
      }
      return false;
    }
  }

  // If no live source element is referenced, every lane of the result is
  // poison - replace the shuffle with poison.
  if (!NewX) {
    replaceValue(I, *PoisonValue::get(ShuffleDstTy));
    return true;
  }
  if (!NewY)
    NewY = PoisonValue::get(ShuffleSrcTy);

  // Have we folded to an identity shuffle?
  if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
    replaceValue(I, *NewX);
    return true;
  }

  // Try to merge the shuffles if the new shuffle is not costly.
  InstructionCost InnerCost0 = 0;
  if (Match0)
    InnerCost0 = TTI.getInstructionCost(cast<Instruction>(OuterV0), CostKind);

  InstructionCost InnerCost1 = 0;
  if (Match1)
    InnerCost1 = TTI.getInstructionCost(cast<Instruction>(OuterV1), CostKind);

  InstructionCost OuterCost = TTI.getShuffleCost(
      TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind,
      0, nullptr, {OuterV0, OuterV1}, &I);

  InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;

  bool IsUnary = all_of(NewMask, [&](int M) { return M < (int)NumSrcElts; });
  TargetTransformInfo::ShuffleKind SK =
      IsUnary ?
TargetTransformInfo::SK_PermuteSingleSrc 1993 : TargetTransformInfo::SK_PermuteTwoSrc; 1994 InstructionCost NewCost = TTI.getShuffleCost( 1995 SK, ShuffleSrcTy, NewMask, CostKind, 0, nullptr, {NewX, NewY}); 1996 if (!OuterV0->hasOneUse()) 1997 NewCost += InnerCost0; 1998 if (!OuterV1->hasOneUse()) 1999 NewCost += InnerCost1; 2000 2001 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I 2002 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 2003 << "\n"); 2004 if (NewCost > OldCost) 2005 return false; 2006 2007 Value *Shuf = Builder.CreateShuffleVector(NewX, NewY, NewMask); 2008 replaceValue(I, *Shuf); 2009 return true; 2010 } 2011 2012 /// Try to convert 2013 /// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)". 2014 bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) { 2015 Value *V0, *V1; 2016 ArrayRef<int> OldMask; 2017 if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)), 2018 m_Mask(OldMask)))) 2019 return false; 2020 2021 auto *II0 = dyn_cast<IntrinsicInst>(V0); 2022 auto *II1 = dyn_cast<IntrinsicInst>(V1); 2023 if (!II0 || !II1) 2024 return false; 2025 2026 Intrinsic::ID IID = II0->getIntrinsicID(); 2027 if (IID != II1->getIntrinsicID()) 2028 return false; 2029 2030 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 2031 auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType()); 2032 if (!ShuffleDstTy || !II0Ty) 2033 return false; 2034 2035 if (!isTriviallyVectorizable(IID)) 2036 return false; 2037 2038 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 2039 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI) && 2040 II0->getArgOperand(I) != II1->getArgOperand(I)) 2041 return false; 2042 2043 InstructionCost OldCost = 2044 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) + 2045 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind) + 2046 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, II0Ty, OldMask, 2047 CostKind, 0, nullptr, {II0, II1}, &I); 2048 2049 SmallVector<Type *> NewArgsTy; 2050 InstructionCost NewCost = 0; 2051 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 2052 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) { 2053 NewArgsTy.push_back(II0->getArgOperand(I)->getType()); 2054 } else { 2055 auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType()); 2056 NewArgsTy.push_back(FixedVectorType::get(VecTy->getElementType(), 2057 VecTy->getNumElements() * 2)); 2058 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, 2059 VecTy, OldMask, CostKind); 2060 } 2061 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy); 2062 NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind); 2063 2064 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I 2065 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 2066 << "\n"); 2067 2068 if (NewCost > OldCost) 2069 return false; 2070 2071 SmallVector<Value *> NewArgs; 2072 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 2073 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) { 2074 NewArgs.push_back(II0->getArgOperand(I)); 2075 } else { 2076 Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I), 2077 II1->getArgOperand(I), OldMask); 2078 NewArgs.push_back(Shuf); 2079 Worklist.pushValue(Shuf); 2080 } 2081 Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs); 2082 2083 // Intersect flags from the old intrinsics. 
2084 if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) { 2085 NewInst->copyIRFlags(II0); 2086 NewInst->andIRFlags(II1); 2087 } 2088 2089 replaceValue(I, *NewIntrinsic); 2090 return true; 2091 } 2092 2093 using InstLane = std::pair<Use *, int>; 2094 2095 static InstLane lookThroughShuffles(Use *U, int Lane) { 2096 while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) { 2097 unsigned NumElts = 2098 cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements(); 2099 int M = SV->getMaskValue(Lane); 2100 if (M < 0) 2101 return {nullptr, PoisonMaskElem}; 2102 if (static_cast<unsigned>(M) < NumElts) { 2103 U = &SV->getOperandUse(0); 2104 Lane = M; 2105 } else { 2106 U = &SV->getOperandUse(1); 2107 Lane = M - NumElts; 2108 } 2109 } 2110 return InstLane{U, Lane}; 2111 } 2112 2113 static SmallVector<InstLane> 2114 generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) { 2115 SmallVector<InstLane> NItem; 2116 for (InstLane IL : Item) { 2117 auto [U, Lane] = IL; 2118 InstLane OpLane = 2119 U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op), 2120 Lane) 2121 : InstLane{nullptr, PoisonMaskElem}; 2122 NItem.emplace_back(OpLane); 2123 } 2124 return NItem; 2125 } 2126 2127 /// Detect concat of multiple values into a vector 2128 static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind, 2129 const TargetTransformInfo &TTI) { 2130 auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType()); 2131 unsigned NumElts = Ty->getNumElements(); 2132 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0) 2133 return false; 2134 2135 // Check that the concat is free, usually meaning that the type will be split 2136 // during legalization. 2137 SmallVector<int, 16> ConcatMask(NumElts * 2); 2138 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 2139 if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, CostKind) != 0) 2140 return false; 2141 2142 unsigned NumSlices = Item.size() / NumElts; 2143 // Currently we generate a tree of shuffles for the concats, which limits us 2144 // to a power2. 
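// For example (illustrative sizes), 4 slices of <4 x i8> are rebuilt as two
// <8 x i8> concat shuffles feeding a final <16 x i8> concat shuffle.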
2145 if (!isPowerOf2_32(NumSlices)) 2146 return false; 2147 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) { 2148 Use *SliceV = Item[Slice * NumElts].first; 2149 if (!SliceV || SliceV->get()->getType() != Ty) 2150 return false; 2151 for (unsigned Elt = 0; Elt < NumElts; ++Elt) { 2152 auto [V, Lane] = Item[Slice * NumElts + Elt]; 2153 if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get()) 2154 return false; 2155 } 2156 } 2157 return true; 2158 } 2159 2160 static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, 2161 const SmallPtrSet<Use *, 4> &IdentityLeafs, 2162 const SmallPtrSet<Use *, 4> &SplatLeafs, 2163 const SmallPtrSet<Use *, 4> &ConcatLeafs, 2164 IRBuilder<> &Builder, 2165 const TargetTransformInfo *TTI) { 2166 auto [FrontU, FrontLane] = Item.front(); 2167 2168 if (IdentityLeafs.contains(FrontU)) { 2169 return FrontU->get(); 2170 } 2171 if (SplatLeafs.contains(FrontU)) { 2172 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane); 2173 return Builder.CreateShuffleVector(FrontU->get(), Mask); 2174 } 2175 if (ConcatLeafs.contains(FrontU)) { 2176 unsigned NumElts = 2177 cast<FixedVectorType>(FrontU->get()->getType())->getNumElements(); 2178 SmallVector<Value *> Values(Item.size() / NumElts, nullptr); 2179 for (unsigned S = 0; S < Values.size(); ++S) 2180 Values[S] = Item[S * NumElts].first->get(); 2181 2182 while (Values.size() > 1) { 2183 NumElts *= 2; 2184 SmallVector<int, 16> Mask(NumElts, 0); 2185 std::iota(Mask.begin(), Mask.end(), 0); 2186 SmallVector<Value *> NewValues(Values.size() / 2, nullptr); 2187 for (unsigned S = 0; S < NewValues.size(); ++S) 2188 NewValues[S] = 2189 Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask); 2190 Values = NewValues; 2191 } 2192 return Values[0]; 2193 } 2194 2195 auto *I = cast<Instruction>(FrontU->get()); 2196 auto *II = dyn_cast<IntrinsicInst>(I); 2197 unsigned NumOps = I->getNumOperands() - (II ? 
1 : 0); 2198 SmallVector<Value *> Ops(NumOps); 2199 for (unsigned Idx = 0; Idx < NumOps; Idx++) { 2200 if (II && 2201 isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) { 2202 Ops[Idx] = II->getOperand(Idx); 2203 continue; 2204 } 2205 Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), 2206 Ty, IdentityLeafs, SplatLeafs, ConcatLeafs, 2207 Builder, TTI); 2208 } 2209 2210 SmallVector<Value *, 8> ValueList; 2211 for (const auto &Lane : Item) 2212 if (Lane.first) 2213 ValueList.push_back(Lane.first->get()); 2214 2215 Type *DstTy = 2216 FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements()); 2217 if (auto *BI = dyn_cast<BinaryOperator>(I)) { 2218 auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), 2219 Ops[0], Ops[1]); 2220 propagateIRFlags(Value, ValueList); 2221 return Value; 2222 } 2223 if (auto *CI = dyn_cast<CmpInst>(I)) { 2224 auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]); 2225 propagateIRFlags(Value, ValueList); 2226 return Value; 2227 } 2228 if (auto *SI = dyn_cast<SelectInst>(I)) { 2229 auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI); 2230 propagateIRFlags(Value, ValueList); 2231 return Value; 2232 } 2233 if (auto *CI = dyn_cast<CastInst>(I)) { 2234 auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(), 2235 Ops[0], DstTy); 2236 propagateIRFlags(Value, ValueList); 2237 return Value; 2238 } 2239 if (II) { 2240 auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops); 2241 propagateIRFlags(Value, ValueList); 2242 return Value; 2243 } 2244 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate"); 2245 auto *Value = 2246 Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]); 2247 propagateIRFlags(Value, ValueList); 2248 return Value; 2249 } 2250 2251 // Starting from a shuffle, look up through operands tracking the shuffled index 2252 // of each lane. If we can simplify away the shuffles to identities then 2253 // do so. 2254 bool VectorCombine::foldShuffleToIdentity(Instruction &I) { 2255 auto *Ty = dyn_cast<FixedVectorType>(I.getType()); 2256 if (!Ty || I.use_empty()) 2257 return false; 2258 2259 SmallVector<InstLane> Start(Ty->getNumElements()); 2260 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M) 2261 Start[M] = lookThroughShuffles(&*I.use_begin(), M); 2262 2263 SmallVector<SmallVector<InstLane>> Worklist; 2264 Worklist.push_back(Start); 2265 SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs; 2266 unsigned NumVisited = 0; 2267 2268 while (!Worklist.empty()) { 2269 if (++NumVisited > MaxInstrsToScan) 2270 return false; 2271 2272 SmallVector<InstLane> Item = Worklist.pop_back_val(); 2273 auto [FrontU, FrontLane] = Item.front(); 2274 2275 // If we found an undef first lane then bail out to keep things simple. 2276 if (!FrontU) 2277 return false; 2278 2279 // Helper to peek through bitcasts to the same value. 2280 auto IsEquiv = [&](Value *X, Value *Y) { 2281 return X->getType() == Y->getType() && 2282 peekThroughBitcasts(X) == peekThroughBitcasts(Y); 2283 }; 2284 2285 // Look for an identity value. 
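// That is, every demanded lane i reads lane i of one common source value with
// the same element count (looking through bitcasts of the same value).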
2286 if (FrontLane == 0 && 2287 cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() == 2288 Ty->getNumElements() && 2289 all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) { 2290 Value *FrontV = Item.front().first->get(); 2291 return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) && 2292 E.value().second == (int)E.index()); 2293 })) { 2294 IdentityLeafs.insert(FrontU); 2295 continue; 2296 } 2297 // Look for constants, for the moment only supporting constant splats. 2298 if (auto *C = dyn_cast<Constant>(FrontU); 2299 C && C->getSplatValue() && 2300 all_of(drop_begin(Item), [Item](InstLane &IL) { 2301 Value *FrontV = Item.front().first->get(); 2302 Use *U = IL.first; 2303 return !U || (isa<Constant>(U->get()) && 2304 cast<Constant>(U->get())->getSplatValue() == 2305 cast<Constant>(FrontV)->getSplatValue()); 2306 })) { 2307 SplatLeafs.insert(FrontU); 2308 continue; 2309 } 2310 // Look for a splat value. 2311 if (all_of(drop_begin(Item), [Item](InstLane &IL) { 2312 auto [FrontU, FrontLane] = Item.front(); 2313 auto [U, Lane] = IL; 2314 return !U || (U->get() == FrontU->get() && Lane == FrontLane); 2315 })) { 2316 SplatLeafs.insert(FrontU); 2317 continue; 2318 } 2319 2320 // We need each element to be the same type of value, and check that each 2321 // element has a single use. 2322 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) { 2323 Value *FrontV = Item.front().first->get(); 2324 if (!IL.first) 2325 return true; 2326 Value *V = IL.first->get(); 2327 if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse()) 2328 return false; 2329 if (V->getValueID() != FrontV->getValueID()) 2330 return false; 2331 if (auto *CI = dyn_cast<CmpInst>(V)) 2332 if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate()) 2333 return false; 2334 if (auto *CI = dyn_cast<CastInst>(V)) 2335 if (CI->getSrcTy()->getScalarType() != 2336 cast<CastInst>(FrontV)->getSrcTy()->getScalarType()) 2337 return false; 2338 if (auto *SI = dyn_cast<SelectInst>(V)) 2339 if (!isa<VectorType>(SI->getOperand(0)->getType()) || 2340 SI->getOperand(0)->getType() != 2341 cast<SelectInst>(FrontV)->getOperand(0)->getType()) 2342 return false; 2343 if (isa<CallInst>(V) && !isa<IntrinsicInst>(V)) 2344 return false; 2345 auto *II = dyn_cast<IntrinsicInst>(V); 2346 return !II || (isa<IntrinsicInst>(FrontV) && 2347 II->getIntrinsicID() == 2348 cast<IntrinsicInst>(FrontV)->getIntrinsicID() && 2349 !II->hasOperandBundles()); 2350 }; 2351 if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) { 2352 // Check the operator is one that we support. 2353 if (isa<BinaryOperator, CmpInst>(FrontU)) { 2354 // We exclude div/rem in case they hit UB from poison lanes. 2355 if (auto *BO = dyn_cast<BinaryOperator>(FrontU); 2356 BO && BO->isIntDivRem()) 2357 return false; 2358 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2359 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); 2360 continue; 2361 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst, 2362 FPToUIInst, SIToFPInst, UIToFPInst>(FrontU)) { 2363 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2364 continue; 2365 } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) { 2366 // TODO: Handle vector widening/narrowing bitcasts. 
2367 auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy()); 2368 auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy()); 2369 if (DstTy && SrcTy && 2370 SrcTy->getNumElements() == DstTy->getNumElements()) { 2371 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2372 continue; 2373 } 2374 } else if (isa<SelectInst>(FrontU)) { 2375 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2376 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); 2377 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2)); 2378 continue; 2379 } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU); 2380 II && isTriviallyVectorizable(II->getIntrinsicID()) && 2381 !II->hasOperandBundles()) { 2382 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) { 2383 if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op, 2384 &TTI)) { 2385 if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) { 2386 Value *FrontV = Item.front().first->get(); 2387 Use *U = IL.first; 2388 return !U || (cast<Instruction>(U->get())->getOperand(Op) == 2389 cast<Instruction>(FrontV)->getOperand(Op)); 2390 })) 2391 return false; 2392 continue; 2393 } 2394 Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op)); 2395 } 2396 continue; 2397 } 2398 } 2399 2400 if (isFreeConcat(Item, CostKind, TTI)) { 2401 ConcatLeafs.insert(FrontU); 2402 continue; 2403 } 2404 2405 return false; 2406 } 2407 2408 if (NumVisited <= 1) 2409 return false; 2410 2411 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n"); 2412 2413 // If we got this far, we know the shuffles are superfluous and can be 2414 // removed. Scan through again and generate the new tree of instructions. 2415 Builder.SetInsertPoint(&I); 2416 Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, 2417 ConcatLeafs, Builder, &TTI); 2418 replaceValue(I, *V); 2419 return true; 2420 } 2421 2422 /// Given a commutative reduction, the order of the input lanes does not alter 2423 /// the results. We can use this to remove certain shuffles feeding the 2424 /// reduction, removing the need to shuffle at all. 2425 bool VectorCombine::foldShuffleFromReductions(Instruction &I) { 2426 auto *II = dyn_cast<IntrinsicInst>(&I); 2427 if (!II) 2428 return false; 2429 switch (II->getIntrinsicID()) { 2430 case Intrinsic::vector_reduce_add: 2431 case Intrinsic::vector_reduce_mul: 2432 case Intrinsic::vector_reduce_and: 2433 case Intrinsic::vector_reduce_or: 2434 case Intrinsic::vector_reduce_xor: 2435 case Intrinsic::vector_reduce_smin: 2436 case Intrinsic::vector_reduce_smax: 2437 case Intrinsic::vector_reduce_umin: 2438 case Intrinsic::vector_reduce_umax: 2439 break; 2440 default: 2441 return false; 2442 } 2443 2444 // Find all the inputs when looking through operations that do not alter the 2445 // lane order (binops, for example). Currently we look for a single shuffle, 2446 // and can ignore splat values. 2447 std::queue<Value *> Worklist; 2448 SmallPtrSet<Value *, 4> Visited; 2449 ShuffleVectorInst *Shuffle = nullptr; 2450 if (auto *Op = dyn_cast<Instruction>(I.getOperand(0))) 2451 Worklist.push(Op); 2452 2453 while (!Worklist.empty()) { 2454 Value *CV = Worklist.front(); 2455 Worklist.pop(); 2456 if (Visited.contains(CV)) 2457 continue; 2458 2459 // Splats don't change the order, so can be safely ignored. 
2460 if (isSplatValue(CV)) 2461 continue; 2462 2463 Visited.insert(CV); 2464 2465 if (auto *CI = dyn_cast<Instruction>(CV)) { 2466 if (CI->isBinaryOp()) { 2467 for (auto *Op : CI->operand_values()) 2468 Worklist.push(Op); 2469 continue; 2470 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) { 2471 if (Shuffle && Shuffle != SV) 2472 return false; 2473 Shuffle = SV; 2474 continue; 2475 } 2476 } 2477 2478 // Anything else is currently an unknown node. 2479 return false; 2480 } 2481 2482 if (!Shuffle) 2483 return false; 2484 2485 // Check all uses of the binary ops and shuffles are also included in the 2486 // lane-invariant operations (Visited should be the list of lanewise 2487 // instructions, including the shuffle that we found). 2488 for (auto *V : Visited) 2489 for (auto *U : V->users()) 2490 if (!Visited.contains(U) && U != &I) 2491 return false; 2492 2493 FixedVectorType *VecType = 2494 dyn_cast<FixedVectorType>(II->getOperand(0)->getType()); 2495 if (!VecType) 2496 return false; 2497 FixedVectorType *ShuffleInputType = 2498 dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType()); 2499 if (!ShuffleInputType) 2500 return false; 2501 unsigned NumInputElts = ShuffleInputType->getNumElements(); 2502 2503 // Find the mask from sorting the lanes into order. This is most likely to 2504 // become a identity or concat mask. Undef elements are pushed to the end. 2505 SmallVector<int> ConcatMask; 2506 Shuffle->getShuffleMask(ConcatMask); 2507 sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; }); 2508 // In the case of a truncating shuffle it's possible for the mask 2509 // to have an index greater than the size of the resulting vector. 2510 // This requires special handling. 2511 bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts; 2512 bool UsesSecondVec = 2513 any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; }); 2514 2515 FixedVectorType *VecTyForCost = 2516 (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType; 2517 InstructionCost OldCost = TTI.getShuffleCost( 2518 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, 2519 VecTyForCost, Shuffle->getShuffleMask(), CostKind); 2520 InstructionCost NewCost = TTI.getShuffleCost( 2521 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, 2522 VecTyForCost, ConcatMask, CostKind); 2523 2524 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle 2525 << "\n"); 2526 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost 2527 << "\n"); 2528 if (NewCost < OldCost) { 2529 Builder.SetInsertPoint(Shuffle); 2530 Value *NewShuffle = Builder.CreateShuffleVector( 2531 Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask); 2532 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n"); 2533 replaceValue(*Shuffle, *NewShuffle); 2534 } 2535 2536 // See if we can re-use foldSelectShuffle, getting it to reduce the size of 2537 // the shuffle into a nicer order, as it can ignore the order of the shuffles. 2538 return foldSelectShuffle(*Shuffle, true); 2539 } 2540 2541 /// Determine if its more efficient to fold: 2542 /// reduce(trunc(x)) -> trunc(reduce(x)). 2543 /// reduce(sext(x)) -> sext(reduce(x)). 2544 /// reduce(zext(x)) -> zext(reduce(x)). 
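/// For illustration (hypothetical IR):
///   %ext = sext <8 x i16> %x to <8 x i32>
///   %r = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %ext)
/// may be cheaper as:
///   %r.src = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> %x)
///   %r = sext i16 %r.src to i32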
2545 bool VectorCombine::foldCastFromReductions(Instruction &I) { 2546 auto *II = dyn_cast<IntrinsicInst>(&I); 2547 if (!II) 2548 return false; 2549 2550 bool TruncOnly = false; 2551 Intrinsic::ID IID = II->getIntrinsicID(); 2552 switch (IID) { 2553 case Intrinsic::vector_reduce_add: 2554 case Intrinsic::vector_reduce_mul: 2555 TruncOnly = true; 2556 break; 2557 case Intrinsic::vector_reduce_and: 2558 case Intrinsic::vector_reduce_or: 2559 case Intrinsic::vector_reduce_xor: 2560 break; 2561 default: 2562 return false; 2563 } 2564 2565 unsigned ReductionOpc = getArithmeticReductionInstruction(IID); 2566 Value *ReductionSrc = I.getOperand(0); 2567 2568 Value *Src; 2569 if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) && 2570 (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src)))))) 2571 return false; 2572 2573 auto CastOpc = 2574 (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode(); 2575 2576 auto *SrcTy = cast<VectorType>(Src->getType()); 2577 auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType()); 2578 Type *ResultTy = I.getType(); 2579 2580 InstructionCost OldCost = TTI.getArithmeticReductionCost( 2581 ReductionOpc, ReductionSrcTy, std::nullopt, CostKind); 2582 OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy, 2583 TTI::CastContextHint::None, CostKind, 2584 cast<CastInst>(ReductionSrc)); 2585 InstructionCost NewCost = 2586 TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt, 2587 CostKind) + 2588 TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(), 2589 TTI::CastContextHint::None, CostKind); 2590 2591 if (OldCost <= NewCost || !NewCost.isValid()) 2592 return false; 2593 2594 Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(), 2595 II->getIntrinsicID(), {Src}); 2596 Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy); 2597 replaceValue(I, *NewCast); 2598 return true; 2599 } 2600 2601 /// This method looks for groups of shuffles acting on binops, of the form: 2602 /// %x = shuffle ... 2603 /// %y = shuffle ... 2604 /// %a = binop %x, %y 2605 /// %b = binop %x, %y 2606 /// shuffle %a, %b, selectmask 2607 /// We may, especially if the shuffle is wider than legal, be able to convert 2608 /// the shuffle to a form where only parts of a and b need to be computed. On 2609 /// architectures with no obvious "select" shuffle, this can reduce the total 2610 /// number of operations if the target reports them as cheaper. 
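/// For illustration (hypothetical lanes), a selectmask of <0, 5, 2, 7> over
/// two v4 binops keeps lanes 0 and 2 of %a and lanes 1 and 3 of %b, so only
/// those lanes of each binop need to be computed once the inputs are packed.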
2611 bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { 2612 auto *SVI = cast<ShuffleVectorInst>(&I); 2613 auto *VT = cast<FixedVectorType>(I.getType()); 2614 auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0)); 2615 auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1)); 2616 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() || 2617 VT != Op0->getType()) 2618 return false; 2619 2620 auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0)); 2621 auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1)); 2622 auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0)); 2623 auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1)); 2624 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B}); 2625 auto checkSVNonOpUses = [&](Instruction *I) { 2626 if (!I || I->getOperand(0)->getType() != VT) 2627 return true; 2628 return any_of(I->users(), [&](User *U) { 2629 return U != Op0 && U != Op1 && 2630 !(isa<ShuffleVectorInst>(U) && 2631 (InputShuffles.contains(cast<Instruction>(U)) || 2632 isInstructionTriviallyDead(cast<Instruction>(U)))); 2633 }); 2634 }; 2635 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) || 2636 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B)) 2637 return false; 2638 2639 // Collect all the uses that are shuffles that we can transform together. We 2640 // may not have a single shuffle, but a group that can all be transformed 2641 // together profitably. 2642 SmallVector<ShuffleVectorInst *> Shuffles; 2643 auto collectShuffles = [&](Instruction *I) { 2644 for (auto *U : I->users()) { 2645 auto *SV = dyn_cast<ShuffleVectorInst>(U); 2646 if (!SV || SV->getType() != VT) 2647 return false; 2648 if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) || 2649 (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1)) 2650 return false; 2651 if (!llvm::is_contained(Shuffles, SV)) 2652 Shuffles.push_back(SV); 2653 } 2654 return true; 2655 }; 2656 if (!collectShuffles(Op0) || !collectShuffles(Op1)) 2657 return false; 2658 // From a reduction, we need to be processing a single shuffle, otherwise the 2659 // other uses will not be lane-invariant. 2660 if (FromReduction && Shuffles.size() > 1) 2661 return false; 2662 2663 // Add any shuffle uses for the shuffles we have found, to include them in our 2664 // cost calculations. 2665 if (!FromReduction) { 2666 for (ShuffleVectorInst *SV : Shuffles) { 2667 for (auto *U : SV->users()) { 2668 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U); 2669 if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT) 2670 Shuffles.push_back(SSV); 2671 } 2672 } 2673 } 2674 2675 // For each of the output shuffles, we try to sort all the first vector 2676 // elements to the beginning, followed by the second array elements at the 2677 // end. If the binops are legalized to smaller vectors, this may reduce total 2678 // number of binops. We compute the ReconstructMask mask needed to convert 2679 // back to the original lane order. 2680 SmallVector<std::pair<int, int>> V1, V2; 2681 SmallVector<SmallVector<int>> OrigReconstructMasks; 2682 int MaxV1Elt = 0, MaxV2Elt = 0; 2683 unsigned NumElts = VT->getNumElements(); 2684 for (ShuffleVectorInst *SVN : Shuffles) { 2685 SmallVector<int> Mask; 2686 SVN->getShuffleMask(Mask); 2687 2688 // Check the operands are the same as the original, or reversed (in which 2689 // case we need to commute the mask). 
    Value *SVOp0 = SVN->getOperand(0);
    Value *SVOp1 = SVN->getOperand(1);
    if (isa<UndefValue>(SVOp1)) {
      auto *SSV = cast<ShuffleVectorInst>(SVOp0);
      SVOp0 = SSV->getOperand(0);
      SVOp1 = SSV->getOperand(1);
      for (unsigned I = 0, E = Mask.size(); I != E; I++) {
        if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
          return false;
        Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
      }
    }
    if (SVOp0 == Op1 && SVOp1 == Op0) {
      std::swap(SVOp0, SVOp1);
      ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
    }
    if (SVOp0 != Op0 || SVOp1 != Op1)
      return false;

    // Calculate the reconstruction mask for this shuffle, i.e. the mask needed
    // to take the packed values from Op0/Op1 and reconstruct the original lane
    // order.
    SmallVector<int> ReconstructMask;
    for (unsigned I = 0; I < Mask.size(); I++) {
      if (Mask[I] < 0) {
        ReconstructMask.push_back(-1);
      } else if (Mask[I] < static_cast<int>(NumElts)) {
        MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
        auto It = find_if(V1, [&](const std::pair<int, int> &A) {
          return Mask[I] == A.first;
        });
        if (It != V1.end())
          ReconstructMask.push_back(It - V1.begin());
        else {
          ReconstructMask.push_back(V1.size());
          V1.emplace_back(Mask[I], V1.size());
        }
      } else {
        MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
        auto It = find_if(V2, [&](const std::pair<int, int> &A) {
          return Mask[I] - static_cast<int>(NumElts) == A.first;
        });
        if (It != V2.end())
          ReconstructMask.push_back(NumElts + It - V2.begin());
        else {
          ReconstructMask.push_back(NumElts + V2.size());
          V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
        }
      }
    }

    // For reductions, we know that the output lane ordering doesn't alter the
    // result. In-order can help simplify the shuffle away.
    if (FromReduction)
      sort(ReconstructMask);
    OrigReconstructMasks.push_back(std::move(ReconstructMask));
  }

  // If the maximum elements used from V1 and V2 are not larger than the new
  // vectors, the vectors are already packed and performing the optimization
  // again will likely not help any further. This also prevents us from getting
  // stuck in a cycle in case the costs do not also rule it out.
  if (V1.empty() || V2.empty() ||
      (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
       MaxV2Elt == static_cast<int>(V2.size()) - 1))
    return false;

  // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
  // shuffle of another shuffle, or not a shuffle (which is treated like an
  // identity shuffle).
  auto GetBaseMaskValue = [&](Instruction *I, int M) {
    auto *SV = dyn_cast<ShuffleVectorInst>(I);
    if (!SV)
      return M;
    if (isa<UndefValue>(SV->getOperand(1)))
      if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
        if (InputShuffles.contains(SSV))
          return SSV->getMaskValue(SV->getMaskValue(M));
    return SV->getMaskValue(M);
  };

  // Attempt to sort the inputs by ascending mask values to make simpler input
  // shuffles and push complex shuffles down to the uses. We sort on the first
  // of the two input shuffle orders, to try and get at least one input into a
  // nice order.
2775 auto SortBase = [&](Instruction *A, std::pair<int, int> X, 2776 std::pair<int, int> Y) { 2777 int MXA = GetBaseMaskValue(A, X.first); 2778 int MYA = GetBaseMaskValue(A, Y.first); 2779 return MXA < MYA; 2780 }; 2781 stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) { 2782 return SortBase(SVI0A, A, B); 2783 }); 2784 stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) { 2785 return SortBase(SVI1A, A, B); 2786 }); 2787 // Calculate our ReconstructMasks from the OrigReconstructMasks and the 2788 // modified order of the input shuffles. 2789 SmallVector<SmallVector<int>> ReconstructMasks; 2790 for (const auto &Mask : OrigReconstructMasks) { 2791 SmallVector<int> ReconstructMask; 2792 for (int M : Mask) { 2793 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) { 2794 auto It = find_if(V, [M](auto A) { return A.second == M; }); 2795 assert(It != V.end() && "Expected all entries in Mask"); 2796 return std::distance(V.begin(), It); 2797 }; 2798 if (M < 0) 2799 ReconstructMask.push_back(-1); 2800 else if (M < static_cast<int>(NumElts)) { 2801 ReconstructMask.push_back(FindIndex(V1, M)); 2802 } else { 2803 ReconstructMask.push_back(NumElts + FindIndex(V2, M)); 2804 } 2805 } 2806 ReconstructMasks.push_back(std::move(ReconstructMask)); 2807 } 2808 2809 // Calculate the masks needed for the new input shuffles, which get padded 2810 // with undef 2811 SmallVector<int> V1A, V1B, V2A, V2B; 2812 for (unsigned I = 0; I < V1.size(); I++) { 2813 V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first)); 2814 V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first)); 2815 } 2816 for (unsigned I = 0; I < V2.size(); I++) { 2817 V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first)); 2818 V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first)); 2819 } 2820 while (V1A.size() < NumElts) { 2821 V1A.push_back(PoisonMaskElem); 2822 V1B.push_back(PoisonMaskElem); 2823 } 2824 while (V2A.size() < NumElts) { 2825 V2A.push_back(PoisonMaskElem); 2826 V2B.push_back(PoisonMaskElem); 2827 } 2828 2829 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) { 2830 auto *SV = dyn_cast<ShuffleVectorInst>(I); 2831 if (!SV) 2832 return C; 2833 return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1)) 2834 ? TTI::SK_PermuteSingleSrc 2835 : TTI::SK_PermuteTwoSrc, 2836 VT, SV->getShuffleMask(), CostKind); 2837 }; 2838 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) { 2839 return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask, CostKind); 2840 }; 2841 2842 // Get the costs of the shuffles + binops before and after with the new 2843 // shuffle masks. 2844 InstructionCost CostBefore = 2845 TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) + 2846 TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind); 2847 CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(), 2848 InstructionCost(0), AddShuffleCost); 2849 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(), 2850 InstructionCost(0), AddShuffleCost); 2851 2852 // The new binops will be unused for lanes past the used shuffle lengths. 2853 // These types attempt to get the correct cost for that from the target. 
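// For example (illustrative), if only 3 of 8 lanes of Op0 survive into the
// output shuffles, Op0's replacement binop is costed as a 3-element vector op.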
2854 FixedVectorType *Op0SmallVT = 2855 FixedVectorType::get(VT->getScalarType(), V1.size()); 2856 FixedVectorType *Op1SmallVT = 2857 FixedVectorType::get(VT->getScalarType(), V2.size()); 2858 InstructionCost CostAfter = 2859 TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) + 2860 TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind); 2861 CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(), 2862 InstructionCost(0), AddShuffleMaskCost); 2863 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B}); 2864 CostAfter += 2865 std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(), 2866 InstructionCost(0), AddShuffleMaskCost); 2867 2868 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n"); 2869 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore 2870 << " vs CostAfter: " << CostAfter << "\n"); 2871 if (CostBefore <= CostAfter) 2872 return false; 2873 2874 // The cost model has passed, create the new instructions. 2875 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * { 2876 auto *SV = dyn_cast<ShuffleVectorInst>(I); 2877 if (!SV) 2878 return I; 2879 if (isa<UndefValue>(SV->getOperand(1))) 2880 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0))) 2881 if (InputShuffles.contains(SSV)) 2882 return SSV->getOperand(Op); 2883 return SV->getOperand(Op); 2884 }; 2885 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef()); 2886 Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0), 2887 GetShuffleOperand(SVI0A, 1), V1A); 2888 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef()); 2889 Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0), 2890 GetShuffleOperand(SVI0B, 1), V1B); 2891 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef()); 2892 Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0), 2893 GetShuffleOperand(SVI1A, 1), V2A); 2894 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef()); 2895 Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0), 2896 GetShuffleOperand(SVI1B, 1), V2B); 2897 Builder.SetInsertPoint(Op0); 2898 Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(), 2899 NSV0A, NSV0B); 2900 if (auto *I = dyn_cast<Instruction>(NOp0)) 2901 I->copyIRFlags(Op0, true); 2902 Builder.SetInsertPoint(Op1); 2903 Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(), 2904 NSV1A, NSV1B); 2905 if (auto *I = dyn_cast<Instruction>(NOp1)) 2906 I->copyIRFlags(Op1, true); 2907 2908 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) { 2909 Builder.SetInsertPoint(Shuffles[S]); 2910 Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]); 2911 replaceValue(*Shuffles[S], *NSV); 2912 } 2913 2914 Worklist.pushValue(NSV0A); 2915 Worklist.pushValue(NSV0B); 2916 Worklist.pushValue(NSV1A); 2917 Worklist.pushValue(NSV1B); 2918 for (auto *S : Shuffles) 2919 Worklist.add(S); 2920 return true; 2921 } 2922 2923 /// Check if instruction depends on ZExt and this ZExt can be moved after the 2924 /// instruction. Move ZExt if it is profitable. For example: 2925 /// logic(zext(x),y) -> zext(logic(x,trunc(y))) 2926 /// lshr((zext(x),y) -> zext(lshr(x,trunc(y))) 2927 /// Cost model calculations takes into account if zext(x) has other users and 2928 /// whether it can be propagated through them too. 
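/// For illustration (hypothetical types): in and(zext(<8 x i8> %x to <8 x i32>), %y)
/// every lane of the result has at most 8 active bits, so the 'and' can be
/// performed in <8 x i8> on %x and trunc(%y), with the zext moved after it.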
2929 bool VectorCombine::shrinkType(Instruction &I) {
2930   Value *ZExted, *OtherOperand;
2931   if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
2932                                   m_Value(OtherOperand))) &&
2933       !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
2934     return false;
2935 
2936   Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
2937 
2938   auto *BigTy = cast<FixedVectorType>(I.getType());
2939   auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
2940   unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
2941 
2942   if (I.getOpcode() == Instruction::LShr) {
2943     // Check that the shift amount is less than the number of bits in the
2944     // smaller type. Otherwise, the smaller lshr will return a poison value.
2945     KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
2946     if (ShAmtKB.getMaxValue().uge(BW))
2947       return false;
2948   } else {
2949     // Check that the expression overall uses at most the same number of bits
2950     // as ZExted.
2951     KnownBits KB = computeKnownBits(&I, *DL);
2952     if (KB.countMaxActiveBits() > BW)
2953       return false;
2954   }
2955 
2956   // Calculate the cost of leaving the current IR as it is versus moving the
2957   // ZExt later, adding truncates where needed.
2958   InstructionCost ZExtCost = TTI.getCastInstrCost(
2959       Instruction::ZExt, BigTy, SmallTy,
2960       TargetTransformInfo::CastContextHint::None, CostKind);
2961   InstructionCost CurrentCost = ZExtCost;
2962   InstructionCost ShrinkCost = 0;
2963 
2964   // Calculate the total cost and check that we can propagate through all ZExt users.
2965   for (User *U : ZExtOperand->users()) {
2966     auto *UI = cast<Instruction>(U);
2967     if (UI == &I) {
2968       CurrentCost +=
2969           TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2970       ShrinkCost +=
2971           TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2972       ShrinkCost += ZExtCost;
2973       continue;
2974     }
2975 
2976     if (!Instruction::isBinaryOp(UI->getOpcode()))
2977       return false;
2978 
2979     // Check if we can propagate the ZExt through its other users.
2980     KnownBits KB = computeKnownBits(UI, *DL);
2981     if (KB.countMaxActiveBits() > BW)
2982       return false;
2983 
2984     CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2985     ShrinkCost +=
2986         TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2987     ShrinkCost += ZExtCost;
2988   }
2989 
2990   // If the other instruction operand is not a constant, we'll need to generate
2991   // a truncate instruction, so the cost has to be adjusted accordingly.
2992   if (!isa<Constant>(OtherOperand))
2993     ShrinkCost += TTI.getCastInstrCost(
2994         Instruction::Trunc, SmallTy, BigTy,
2995         TargetTransformInfo::CastContextHint::None, CostKind);
2996 
2997   // If the cost of shrinking types and leaving the IR is the same, we'll lean
2998   // towards modifying the IR because shrinking opens opportunities for other
2999   // shrinking optimizations.
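  // (Hence the strict '>' below: when the costs tie, we still shrink.)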
3000   if (ShrinkCost > CurrentCost)
3001     return false;
3002 
3003   Builder.SetInsertPoint(&I);
3004   Value *Op0 = ZExted;
3005   Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
3006   // Keep the order of the operands the same.
3007   if (I.getOperand(0) == OtherOperand)
3008     std::swap(Op0, Op1);
3009   Value *NewBinOp =
3010       Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
3011   cast<Instruction>(NewBinOp)->copyIRFlags(&I);
3012   cast<Instruction>(NewBinOp)->copyMetadata(I);
3013   Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
3014   replaceValue(I, *NewZExtr);
3015   return true;
3016 }
3017 
3018 /// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
3019 /// shuffle (DstVec, SrcVec, Mask)
3020 bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
3021   Value *DstVec, *SrcVec;
3022   uint64_t ExtIdx, InsIdx;
3023   if (!match(&I,
3024              m_InsertElt(m_Value(DstVec),
3025                          m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
3026                          m_ConstantInt(InsIdx))))
3027     return false;
3028 
3029   auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
3030   if (!VecTy || SrcVec->getType() != VecTy)
3031     return false;
3032 
3033   unsigned NumElts = VecTy->getNumElements();
3034   if (ExtIdx >= NumElts || InsIdx >= NumElts)
3035     return false;
3036 
3037   // Insertion into poison is a cheaper single-operand shuffle.
3038   TargetTransformInfo::ShuffleKind SK;
3039   SmallVector<int> Mask(NumElts, PoisonMaskElem);
3040   if (isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
3041     SK = TargetTransformInfo::SK_PermuteSingleSrc;
3042     Mask[InsIdx] = ExtIdx;
3043     std::swap(DstVec, SrcVec);
3044   } else {
3045     SK = TargetTransformInfo::SK_PermuteTwoSrc;
3046     std::iota(Mask.begin(), Mask.end(), 0);
3047     Mask[InsIdx] = ExtIdx + NumElts;
3048   }
3049 
3050   // Cost
3051   auto *Ins = cast<InsertElementInst>(&I);
3052   auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
3053   InstructionCost InsCost =
3054       TTI.getVectorInstrCost(*Ins, VecTy, CostKind, InsIdx);
3055   InstructionCost ExtCost =
3056       TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx);
3057   InstructionCost OldCost = ExtCost + InsCost;
3058 
3059   InstructionCost NewCost = TTI.getShuffleCost(SK, VecTy, Mask, CostKind, 0,
3060                                                nullptr, {DstVec, SrcVec});
3061   if (!Ext->hasOneUse())
3062     NewCost += ExtCost;
3063 
3064   LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair: " << I
3065                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
3066                     << "\n");
3067 
3068   if (OldCost < NewCost)
3069     return false;
3070 
3071   // Canonicalize undef param to RHS to help further folds.
3072   if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
3073     ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
3074     std::swap(DstVec, SrcVec);
3075   }
3076 
3077   Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
3078   replaceValue(I, *Shuf);
3079 
3080   return true;
3081 }
3082 
3083 /// This is the entry point for all transforms. Pass manager differences are
3084 /// handled in the callers of this function.
3085 bool VectorCombine::run() {
3086   if (DisableVectorCombine)
3087     return false;
3088 
3089   // Don't attempt vectorization if the target does not support vectors.
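  // (On such targets getNumberOfRegisters() for the vector register class is
  // expected to be zero, so we bail out before scanning any blocks.)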
3090 if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true))) 3091 return false; 3092 3093 LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n"); 3094 3095 bool MadeChange = false; 3096 auto FoldInst = [this, &MadeChange](Instruction &I) { 3097 Builder.SetInsertPoint(&I); 3098 bool IsVectorType = isa<VectorType>(I.getType()); 3099 bool IsFixedVectorType = isa<FixedVectorType>(I.getType()); 3100 auto Opcode = I.getOpcode(); 3101 3102 LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n'); 3103 3104 // These folds should be beneficial regardless of when this pass is run 3105 // in the optimization pipeline. 3106 // The type checking is for run-time efficiency. We can avoid wasting time 3107 // dispatching to folding functions if there's no chance of matching. 3108 if (IsFixedVectorType) { 3109 switch (Opcode) { 3110 case Instruction::InsertElement: 3111 MadeChange |= vectorizeLoadInsert(I); 3112 break; 3113 case Instruction::ShuffleVector: 3114 MadeChange |= widenSubvectorLoad(I); 3115 break; 3116 default: 3117 break; 3118 } 3119 } 3120 3121 // This transform works with scalable and fixed vectors 3122 // TODO: Identify and allow other scalable transforms 3123 if (IsVectorType) { 3124 MadeChange |= scalarizeBinopOrCmp(I); 3125 MadeChange |= scalarizeLoadExtract(I); 3126 MadeChange |= scalarizeVPIntrinsic(I); 3127 } 3128 3129 if (Opcode == Instruction::Store) 3130 MadeChange |= foldSingleElementStore(I); 3131 3132 // If this is an early pipeline invocation of this pass, we are done. 3133 if (TryEarlyFoldsOnly) 3134 return; 3135 3136 // Otherwise, try folds that improve codegen but may interfere with 3137 // early IR canonicalizations. 3138 // The type checking is for run-time efficiency. We can avoid wasting time 3139 // dispatching to folding functions if there's no chance of matching. 3140 if (IsFixedVectorType) { 3141 switch (Opcode) { 3142 case Instruction::InsertElement: 3143 MadeChange |= foldInsExtFNeg(I); 3144 MadeChange |= foldInsExtVectorToShuffle(I); 3145 break; 3146 case Instruction::ShuffleVector: 3147 MadeChange |= foldPermuteOfBinops(I); 3148 MadeChange |= foldShuffleOfBinops(I); 3149 MadeChange |= foldShuffleOfCastops(I); 3150 MadeChange |= foldShuffleOfShuffles(I); 3151 MadeChange |= foldShuffleOfIntrinsics(I); 3152 MadeChange |= foldSelectShuffle(I); 3153 MadeChange |= foldShuffleToIdentity(I); 3154 break; 3155 case Instruction::BitCast: 3156 MadeChange |= foldBitcastShuffle(I); 3157 break; 3158 default: 3159 MadeChange |= shrinkType(I); 3160 break; 3161 } 3162 } else { 3163 switch (Opcode) { 3164 case Instruction::Call: 3165 MadeChange |= foldShuffleFromReductions(I); 3166 MadeChange |= foldCastFromReductions(I); 3167 break; 3168 case Instruction::ICmp: 3169 case Instruction::FCmp: 3170 MadeChange |= foldExtractExtract(I); 3171 break; 3172 case Instruction::Or: 3173 MadeChange |= foldConcatOfBoolMasks(I); 3174 [[fallthrough]]; 3175 default: 3176 if (Instruction::isBinaryOp(Opcode)) { 3177 MadeChange |= foldExtractExtract(I); 3178 MadeChange |= foldExtractedCmps(I); 3179 } 3180 break; 3181 } 3182 } 3183 }; 3184 3185 for (BasicBlock &BB : F) { 3186 // Ignore unreachable basic blocks. 3187 if (!DT.isReachableFromEntry(&BB)) 3188 continue; 3189 // Use early increment range so that we can erase instructions in loop. 
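    // (make_early_inc_range advances the iterator before each body runs, so a
    // fold that erases the instruction just visited does not invalidate the
    // traversal.)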
3190 for (Instruction &I : make_early_inc_range(BB)) { 3191 if (I.isDebugOrPseudoInst()) 3192 continue; 3193 FoldInst(I); 3194 } 3195 } 3196 3197 while (!Worklist.isEmpty()) { 3198 Instruction *I = Worklist.removeOne(); 3199 if (!I) 3200 continue; 3201 3202 if (isInstructionTriviallyDead(I)) { 3203 eraseInstruction(*I); 3204 continue; 3205 } 3206 3207 FoldInst(*I); 3208 } 3209 3210 return MadeChange; 3211 } 3212 3213 PreservedAnalyses VectorCombinePass::run(Function &F, 3214 FunctionAnalysisManager &FAM) { 3215 auto &AC = FAM.getResult<AssumptionAnalysis>(F); 3216 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); 3217 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); 3218 AAResults &AA = FAM.getResult<AAManager>(F); 3219 const DataLayout *DL = &F.getDataLayout(); 3220 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput, 3221 TryEarlyFoldsOnly); 3222 if (!Combiner.run()) 3223 return PreservedAnalyses::all(); 3224 PreservedAnalyses PA; 3225 PA.preserveSet<CFGAnalyses>(); 3226 return PA; 3227 } 3228
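
// Usage note (illustrative, not part of the pass itself): the transforms above
// can be exercised in isolation with
//   opt -passes=vector-combine -S input.ll
// The early-pipeline invocation constructs VectorCombine with
// TryEarlyFoldsOnly=true, which stops FoldInst at the early return above.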