1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass optimizes scalar/vector interactions using target cost models. The 10 // transforms implemented here may not fit in traditional loop-based or SLP 11 // vectorization passes. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/Vectorize/VectorCombine.h" 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/ADT/Statistic.h" 20 #include "llvm/Analysis/AssumptionCache.h" 21 #include "llvm/Analysis/BasicAliasAnalysis.h" 22 #include "llvm/Analysis/GlobalsModRef.h" 23 #include "llvm/Analysis/Loads.h" 24 #include "llvm/Analysis/TargetTransformInfo.h" 25 #include "llvm/Analysis/ValueTracking.h" 26 #include "llvm/Analysis/VectorUtils.h" 27 #include "llvm/IR/Dominators.h" 28 #include "llvm/IR/Function.h" 29 #include "llvm/IR/IRBuilder.h" 30 #include "llvm/IR/PatternMatch.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Transforms/Utils/Local.h" 33 #include "llvm/Transforms/Utils/LoopUtils.h" 34 #include <numeric> 35 #include <queue> 36 37 #define DEBUG_TYPE "vector-combine" 38 #include "llvm/Transforms/Utils/InstructionWorklist.h" 39 40 using namespace llvm; 41 using namespace llvm::PatternMatch; 42 43 STATISTIC(NumVecLoad, "Number of vector loads formed"); 44 STATISTIC(NumVecCmp, "Number of vector compares formed"); 45 STATISTIC(NumVecBO, "Number of vector binops formed"); 46 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed"); 47 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast"); 48 STATISTIC(NumScalarBO, "Number of scalar binops formed"); 49 STATISTIC(NumScalarCmp, "Number of scalar compares formed"); 50 51 static cl::opt<bool> DisableVectorCombine( 52 "disable-vector-combine", cl::init(false), cl::Hidden, 53 cl::desc("Disable all vector combine transforms")); 54 55 static cl::opt<bool> DisableBinopExtractShuffle( 56 "disable-binop-extract-shuffle", cl::init(false), cl::Hidden, 57 cl::desc("Disable binop extract to shuffle transforms")); 58 59 static cl::opt<unsigned> MaxInstrsToScan( 60 "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden, 61 cl::desc("Max number of instructions to scan for vector combining.")); 62 63 static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max(); 64 65 namespace { 66 class VectorCombine { 67 public: 68 VectorCombine(Function &F, const TargetTransformInfo &TTI, 69 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC, 70 const DataLayout *DL, TTI::TargetCostKind CostKind, 71 bool TryEarlyFoldsOnly) 72 : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), DL(DL), 73 CostKind(CostKind), TryEarlyFoldsOnly(TryEarlyFoldsOnly) {} 74 75 bool run(); 76 77 private: 78 Function &F; 79 IRBuilder<> Builder; 80 const TargetTransformInfo &TTI; 81 const DominatorTree &DT; 82 AAResults &AA; 83 AssumptionCache &AC; 84 const DataLayout *DL; 85 TTI::TargetCostKind CostKind; 86 87 /// If true, only perform beneficial early IR transforms. Do not introduce new 88 /// vector operations. 
89 bool TryEarlyFoldsOnly; 90 91 InstructionWorklist Worklist; 92 93 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction" 94 // parameter. That should be updated to specific sub-classes because the 95 // run loop was changed to dispatch on opcode. 96 bool vectorizeLoadInsert(Instruction &I); 97 bool widenSubvectorLoad(Instruction &I); 98 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0, 99 ExtractElementInst *Ext1, 100 unsigned PreferredExtractIndex) const; 101 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1, 102 const Instruction &I, 103 ExtractElementInst *&ConvertToShuffle, 104 unsigned PreferredExtractIndex); 105 void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1, 106 Instruction &I); 107 void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1, 108 Instruction &I); 109 bool foldExtractExtract(Instruction &I); 110 bool foldInsExtFNeg(Instruction &I); 111 bool foldInsExtVectorToShuffle(Instruction &I); 112 bool foldBitcastShuffle(Instruction &I); 113 bool scalarizeBinopOrCmp(Instruction &I); 114 bool scalarizeVPIntrinsic(Instruction &I); 115 bool foldExtractedCmps(Instruction &I); 116 bool foldSingleElementStore(Instruction &I); 117 bool scalarizeLoadExtract(Instruction &I); 118 bool foldConcatOfBoolMasks(Instruction &I); 119 bool foldPermuteOfBinops(Instruction &I); 120 bool foldShuffleOfBinops(Instruction &I); 121 bool foldShuffleOfCastops(Instruction &I); 122 bool foldShuffleOfShuffles(Instruction &I); 123 bool foldShuffleOfIntrinsics(Instruction &I); 124 bool foldShuffleToIdentity(Instruction &I); 125 bool foldShuffleFromReductions(Instruction &I); 126 bool foldCastFromReductions(Instruction &I); 127 bool foldSelectShuffle(Instruction &I, bool FromReduction = false); 128 bool shrinkType(Instruction &I); 129 130 void replaceValue(Value &Old, Value &New) { 131 Old.replaceAllUsesWith(&New); 132 if (auto *NewI = dyn_cast<Instruction>(&New)) { 133 New.takeName(&Old); 134 Worklist.pushUsersToWorkList(*NewI); 135 Worklist.pushValue(NewI); 136 } 137 Worklist.pushValue(&Old); 138 } 139 140 void eraseInstruction(Instruction &I) { 141 LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n'); 142 for (Value *Op : I.operands()) 143 Worklist.pushValue(Op); 144 Worklist.remove(&I); 145 I.eraseFromParent(); 146 } 147 }; 148 } // namespace 149 150 /// Return the source operand of a potentially bitcasted value. If there is no 151 /// bitcast, return the input value itself. 152 static Value *peekThroughBitcasts(Value *V) { 153 while (auto *BitCast = dyn_cast<BitCastInst>(V)) 154 V = BitCast->getOperand(0); 155 return V; 156 } 157 158 static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) { 159 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan. 160 // The widened load may load data from dirty regions or create data races 161 // non-existent in the source. 162 if (!Load || !Load->isSimple() || !Load->hasOneUse() || 163 Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) || 164 mustSuppressSpeculation(*Load)) 165 return false; 166 167 // We are potentially transforming byte-sized (8-bit) memory accesses, so make 168 // sure we have all of our type-based constraints in place for this target. 
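  // Illustrative example (assumed target values, not from a test): with
  // MinVectorSize == 128 and a 32-bit scalar, the checks below pass
  // (128 % 32 == 0, 32 % 8 == 0); an i1 scalar is rejected because it is not
  // byte-sized.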
  Type *ScalarTy = Load->getType()->getScalarType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
      ScalarSize % 8 != 0)
    return false;

  return true;
}

bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  // TODO: Handle non-zero insert index.
  Value *Scalar;
  if (!match(&I,
             m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt())))
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  auto *Load = dyn_cast<LoadInst>(X);
  if (!canWidenLoad(Load, TTI))
    return false;

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");

  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  unsigned OffsetEltIndex = 0;
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                   &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it is safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
    unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
    APInt Offset(OffsetBitWidth, 0);
    SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);

    // We want to shuffle the result down from a high element of a vector, so
    // the offset must be positive.
    if (Offset.isNegative())
      return false;

    // The offset must be a multiple of the scalar element size so that we can
    // shuffle cleanly by whole elements.
    uint64_t ScalarSizeInBytes = ScalarSize / 8;
    if (Offset.urem(ScalarSizeInBytes) != 0)
      return false;

    // If we load MinVecNumElts, will our target element still be loaded?
    OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
    if (OffsetEltIndex >= MinVecNumElts)
      return false;

    if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                     &DT))
      return false;

    // Update alignment with offset value. Note that the offset could be
    // negated to more accurately represent "(new) SrcPtr - Offset = (old)
    // SrcPtr", but negation does not change the result of the alignment
    // calculation.
    Alignment = commonAlignment(Alignment, Offset.getZExtValue());
  }

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
250 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment); 251 Type *LoadTy = Load->getType(); 252 unsigned AS = Load->getPointerAddressSpace(); 253 InstructionCost OldCost = 254 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind); 255 APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0); 256 OldCost += 257 TTI.getScalarizationOverhead(MinVecTy, DemandedElts, 258 /* Insert */ true, HasExtract, CostKind); 259 260 // New pattern: load VecPtr 261 InstructionCost NewCost = 262 TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind); 263 // Optionally, we are shuffling the loaded vector element(s) into place. 264 // For the mask set everything but element 0 to undef to prevent poison from 265 // propagating from the extra loaded memory. This will also optionally 266 // shrink/grow the vector from the loaded size to the output size. 267 // We assume this operation has no cost in codegen if there was no offset. 268 // Note that we could use freeze to avoid poison problems, but then we might 269 // still need a shuffle to change the vector size. 270 auto *Ty = cast<FixedVectorType>(I.getType()); 271 unsigned OutputNumElts = Ty->getNumElements(); 272 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem); 273 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big"); 274 Mask[0] = OffsetEltIndex; 275 if (OffsetEltIndex) 276 NewCost += 277 TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask, CostKind); 278 279 // We can aggressively convert to the vector form because the backend can 280 // invert this transform if it does not result in a performance win. 281 if (OldCost < NewCost || !NewCost.isValid()) 282 return false; 283 284 // It is safe and potentially profitable to load a vector directly: 285 // inselt undef, load Scalar, 0 --> load VecPtr 286 IRBuilder<> Builder(Load); 287 Value *CastedPtr = 288 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS)); 289 Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment); 290 VecLd = Builder.CreateShuffleVector(VecLd, Mask); 291 292 replaceValue(I, *VecLd); 293 ++NumVecLoad; 294 return true; 295 } 296 297 /// If we are loading a vector and then inserting it into a larger vector with 298 /// undefined elements, try to load the larger vector and eliminate the insert. 299 /// This removes a shuffle in IR and may allow combining of other loaded values. 300 bool VectorCombine::widenSubvectorLoad(Instruction &I) { 301 // Match subvector insert of fixed vector. 302 auto *Shuf = cast<ShuffleVectorInst>(&I); 303 if (!Shuf->isIdentityWithPadding()) 304 return false; 305 306 // Allow a non-canonical shuffle mask that is choosing elements from op1. 307 unsigned NumOpElts = 308 cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements(); 309 unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) { 310 return M >= (int)(NumOpElts); 311 }); 312 313 auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex)); 314 if (!canWidenLoad(Load, TTI)) 315 return false; 316 317 // We use minimal alignment (maximum flexibility) because we only care about 318 // the dereferenceable region. When calculating cost and creating a new op, 319 // we may use a larger value based on alignment attributes. 
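  // Illustrative example of the overall fold (assumed IR, not from a test):
  //   %v = load <2 x float>, ptr %p
  //   %w = shufflevector <2 x float> %v, <2 x float> poison,
  //                      <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
  // becomes, when the wider load is dereferenceable and no more expensive:
  //   %w = load <4 x float>, ptr %p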
320 auto *Ty = cast<FixedVectorType>(I.getType()); 321 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); 322 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); 323 Align Alignment = Load->getAlign(); 324 if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT)) 325 return false; 326 327 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment); 328 Type *LoadTy = Load->getType(); 329 unsigned AS = Load->getPointerAddressSpace(); 330 331 // Original pattern: insert_subvector (load PtrOp) 332 // This conservatively assumes that the cost of a subvector insert into an 333 // undef value is 0. We could add that cost if the cost model accurately 334 // reflects the real cost of that operation. 335 InstructionCost OldCost = 336 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind); 337 338 // New pattern: load PtrOp 339 InstructionCost NewCost = 340 TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind); 341 342 // We can aggressively convert to the vector form because the backend can 343 // invert this transform if it does not result in a performance win. 344 if (OldCost < NewCost || !NewCost.isValid()) 345 return false; 346 347 IRBuilder<> Builder(Load); 348 Value *CastedPtr = 349 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS)); 350 Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment); 351 replaceValue(I, *VecLd); 352 ++NumVecLoad; 353 return true; 354 } 355 356 /// Determine which, if any, of the inputs should be replaced by a shuffle 357 /// followed by extract from a different index. 358 ExtractElementInst *VectorCombine::getShuffleExtract( 359 ExtractElementInst *Ext0, ExtractElementInst *Ext1, 360 unsigned PreferredExtractIndex = InvalidIndex) const { 361 auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand()); 362 auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand()); 363 assert(Index0C && Index1C && "Expected constant extract indexes"); 364 365 unsigned Index0 = Index0C->getZExtValue(); 366 unsigned Index1 = Index1C->getZExtValue(); 367 368 // If the extract indexes are identical, no shuffle is needed. 369 if (Index0 == Index1) 370 return nullptr; 371 372 Type *VecTy = Ext0->getVectorOperand()->getType(); 373 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types"); 374 InstructionCost Cost0 = 375 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0); 376 InstructionCost Cost1 = 377 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1); 378 379 // If both costs are invalid no shuffle is needed 380 if (!Cost0.isValid() && !Cost1.isValid()) 381 return nullptr; 382 383 // We are extracting from 2 different indexes, so one operand must be shuffled 384 // before performing a vector operation and/or extract. The more expensive 385 // extract will be replaced by a shuffle. 386 if (Cost0 > Cost1) 387 return Ext0; 388 if (Cost1 > Cost0) 389 return Ext1; 390 391 // If the costs are equal and there is a preferred extract index, shuffle the 392 // opposite operand. 393 if (PreferredExtractIndex == Index0) 394 return Ext1; 395 if (PreferredExtractIndex == Index1) 396 return Ext0; 397 398 // Otherwise, replace the extract with the higher index. 399 return Index0 > Index1 ? Ext0 : Ext1; 400 } 401 402 /// Compare the relative costs of 2 extracts followed by scalar operation vs. 403 /// vector operation(s) followed by extract. Return true if the existing 404 /// instructions are cheaper than a vector alternative. 
Otherwise, return false 405 /// and if one of the extracts should be transformed to a shufflevector, set 406 /// \p ConvertToShuffle to that extract instruction. 407 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0, 408 ExtractElementInst *Ext1, 409 const Instruction &I, 410 ExtractElementInst *&ConvertToShuffle, 411 unsigned PreferredExtractIndex) { 412 auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand()); 413 auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand()); 414 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes"); 415 416 unsigned Opcode = I.getOpcode(); 417 Value *Ext0Src = Ext0->getVectorOperand(); 418 Value *Ext1Src = Ext1->getVectorOperand(); 419 Type *ScalarTy = Ext0->getType(); 420 auto *VecTy = cast<VectorType>(Ext0Src->getType()); 421 InstructionCost ScalarOpCost, VectorOpCost; 422 423 // Get cost estimates for scalar and vector versions of the operation. 424 bool IsBinOp = Instruction::isBinaryOp(Opcode); 425 if (IsBinOp) { 426 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind); 427 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind); 428 } else { 429 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && 430 "Expected a compare"); 431 CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate(); 432 ScalarOpCost = TTI.getCmpSelInstrCost( 433 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind); 434 VectorOpCost = TTI.getCmpSelInstrCost( 435 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind); 436 } 437 438 // Get cost estimates for the extract elements. These costs will factor into 439 // both sequences. 440 unsigned Ext0Index = Ext0IndexC->getZExtValue(); 441 unsigned Ext1Index = Ext1IndexC->getZExtValue(); 442 443 InstructionCost Extract0Cost = 444 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index); 445 InstructionCost Extract1Cost = 446 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index); 447 448 // A more expensive extract will always be replaced by a splat shuffle. 449 // For example, if Ext0 is more expensive: 450 // opcode (extelt V0, Ext0), (ext V1, Ext1) --> 451 // extelt (opcode (splat V0, Ext0), V1), Ext1 452 // TODO: Evaluate whether that always results in lowest cost. Alternatively, 453 // check the cost of creating a broadcast shuffle and shuffling both 454 // operands to element 0. 455 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index; 456 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index; 457 InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost); 458 459 // Extra uses of the extracts mean that we include those costs in the 460 // vector total because those instructions will not be eliminated. 461 InstructionCost OldCost, NewCost; 462 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) { 463 // Handle a special case. If the 2 extracts are identical, adjust the 464 // formulas to account for that. The extra use charge allows for either the 465 // CSE'd pattern or an unoptimized form with identical values: 466 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C 467 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2) 468 : !Ext0->hasOneUse() || !Ext1->hasOneUse(); 469 OldCost = CheapExtractCost + ScalarOpCost; 470 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost; 471 } else { 472 // Handle the general case. 
Each extract is actually a different value: 473 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C 474 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost; 475 NewCost = VectorOpCost + CheapExtractCost + 476 !Ext0->hasOneUse() * Extract0Cost + 477 !Ext1->hasOneUse() * Extract1Cost; 478 } 479 480 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex); 481 if (ConvertToShuffle) { 482 if (IsBinOp && DisableBinopExtractShuffle) 483 return true; 484 485 // If we are extracting from 2 different indexes, then one operand must be 486 // shuffled before performing the vector operation. The shuffle mask is 487 // poison except for 1 lane that is being translated to the remaining 488 // extraction lane. Therefore, it is a splat shuffle. Ex: 489 // ShufMask = { poison, poison, 0, poison } 490 // TODO: The cost model has an option for a "broadcast" shuffle 491 // (splat-from-element-0), but no option for a more general splat. 492 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) { 493 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(), 494 PoisonMaskElem); 495 ShuffleMask[BestInsIndex] = BestExtIndex; 496 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, 497 VecTy, ShuffleMask, CostKind, 0, nullptr, 498 {ConvertToShuffle}); 499 } else { 500 NewCost += 501 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy, 502 {}, CostKind, 0, nullptr, {ConvertToShuffle}); 503 } 504 } 505 506 // Aggressively form a vector op if the cost is equal because the transform 507 // may enable further optimization. 508 // Codegen can reverse this transform (scalarize) if it was not profitable. 509 return OldCost < NewCost; 510 } 511 512 /// Create a shuffle that translates (shifts) 1 element from the input vector 513 /// to a new element location. 514 static Value *createShiftShuffle(Value *Vec, unsigned OldIndex, 515 unsigned NewIndex, IRBuilder<> &Builder) { 516 // The shuffle mask is poison except for 1 lane that is being translated 517 // to the new element index. Example for OldIndex == 2 and NewIndex == 0: 518 // ShufMask = { 2, poison, poison, poison } 519 auto *VecTy = cast<FixedVectorType>(Vec->getType()); 520 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem); 521 ShufMask[NewIndex] = OldIndex; 522 return Builder.CreateShuffleVector(Vec, ShufMask, "shift"); 523 } 524 525 /// Given an extract element instruction with constant index operand, shuffle 526 /// the source vector (shift the scalar element) to a NewIndex for extraction. 527 /// Return null if the input can be constant folded, so that we are not creating 528 /// unnecessary instructions. 529 static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt, 530 unsigned NewIndex, 531 IRBuilder<> &Builder) { 532 // Shufflevectors can only be created for fixed-width vectors. 533 Value *X = ExtElt->getVectorOperand(); 534 if (!isa<FixedVectorType>(X->getType())) 535 return nullptr; 536 537 // If the extract can be constant-folded, this code is unsimplified. Defer 538 // to other passes to handle that. 
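  // Illustrative example (assumed): translating
  //   %e = extractelement <4 x i32> %x, i64 2
  // to NewIndex == 0 produces
  //   %shift = shufflevector <4 x i32> %x, <4 x i32> poison,
  //                          <4 x i32> <i32 2, i32 poison, i32 poison, i32 poison>
  //   %e.new = extractelement <4 x i32> %shift, i64 0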
539 Value *C = ExtElt->getIndexOperand(); 540 assert(isa<ConstantInt>(C) && "Expected a constant index operand"); 541 if (isa<Constant>(X)) 542 return nullptr; 543 544 Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(), 545 NewIndex, Builder); 546 return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex)); 547 } 548 549 /// Try to reduce extract element costs by converting scalar compares to vector 550 /// compares followed by extract. 551 /// cmp (ext0 V0, C), (ext1 V1, C) 552 void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0, 553 ExtractElementInst *Ext1, Instruction &I) { 554 assert(isa<CmpInst>(&I) && "Expected a compare"); 555 assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() == 556 cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() && 557 "Expected matching constant extract indexes"); 558 559 // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C 560 ++NumVecCmp; 561 CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate(); 562 Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand(); 563 Value *VecCmp = Builder.CreateCmp(Pred, V0, V1); 564 Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand()); 565 replaceValue(I, *NewExt); 566 } 567 568 /// Try to reduce extract element costs by converting scalar binops to vector 569 /// binops followed by extract. 570 /// bo (ext0 V0, C), (ext1 V1, C) 571 void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0, 572 ExtractElementInst *Ext1, Instruction &I) { 573 assert(isa<BinaryOperator>(&I) && "Expected a binary operator"); 574 assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() == 575 cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() && 576 "Expected matching constant extract indexes"); 577 578 // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C 579 ++NumVecBO; 580 Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand(); 581 Value *VecBO = 582 Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1); 583 584 // All IR flags are safe to back-propagate because any potential poison 585 // created in unused vector elements is discarded by the extract. 586 if (auto *VecBOInst = dyn_cast<Instruction>(VecBO)) 587 VecBOInst->copyIRFlags(&I); 588 589 Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand()); 590 replaceValue(I, *NewExt); 591 } 592 593 /// Match an instruction with extracted vector operands. 594 bool VectorCombine::foldExtractExtract(Instruction &I) { 595 // It is not safe to transform things like div, urem, etc. because we may 596 // create undefined behavior when executing those on unknown vector elements. 597 if (!isSafeToSpeculativelyExecute(&I)) 598 return false; 599 600 Instruction *I0, *I1; 601 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE; 602 if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) && 603 !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1)))) 604 return false; 605 606 Value *V0, *V1; 607 uint64_t C0, C1; 608 if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) || 609 !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) || 610 V0->getType() != V1->getType()) 611 return false; 612 613 // If the scalar value 'I' is going to be re-inserted into a vector, then try 614 // to create an extract to that same element. The extract/insert can be 615 // reduced to a "select shuffle". 
616 // TODO: If we add a larger pattern match that starts from an insert, this 617 // probably becomes unnecessary. 618 auto *Ext0 = cast<ExtractElementInst>(I0); 619 auto *Ext1 = cast<ExtractElementInst>(I1); 620 uint64_t InsertIndex = InvalidIndex; 621 if (I.hasOneUse()) 622 match(I.user_back(), 623 m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex))); 624 625 ExtractElementInst *ExtractToChange; 626 if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex)) 627 return false; 628 629 if (ExtractToChange) { 630 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0; 631 ExtractElementInst *NewExtract = 632 translateExtract(ExtractToChange, CheapExtractIdx, Builder); 633 if (!NewExtract) 634 return false; 635 if (ExtractToChange == Ext0) 636 Ext0 = NewExtract; 637 else 638 Ext1 = NewExtract; 639 } 640 641 if (Pred != CmpInst::BAD_ICMP_PREDICATE) 642 foldExtExtCmp(Ext0, Ext1, I); 643 else 644 foldExtExtBinop(Ext0, Ext1, I); 645 646 Worklist.push(Ext0); 647 Worklist.push(Ext1); 648 return true; 649 } 650 651 /// Try to replace an extract + scalar fneg + insert with a vector fneg + 652 /// shuffle. 653 bool VectorCombine::foldInsExtFNeg(Instruction &I) { 654 // Match an insert (op (extract)) pattern. 655 Value *DestVec; 656 uint64_t Index; 657 Instruction *FNeg; 658 if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)), 659 m_ConstantInt(Index)))) 660 return false; 661 662 // Note: This handles the canonical fneg instruction and "fsub -0.0, X". 663 Value *SrcVec; 664 Instruction *Extract; 665 if (!match(FNeg, m_FNeg(m_CombineAnd( 666 m_Instruction(Extract), 667 m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index)))))) 668 return false; 669 670 auto *VecTy = cast<FixedVectorType>(I.getType()); 671 auto *ScalarTy = VecTy->getScalarType(); 672 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType()); 673 if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType()) 674 return false; 675 676 // Ignore bogus insert/extract index. 677 unsigned NumElts = VecTy->getNumElements(); 678 if (Index >= NumElts) 679 return false; 680 681 // We are inserting the negated element into the same lane that we extracted 682 // from. This is equivalent to a select-shuffle that chooses all but the 683 // negated element from the destination vector. 684 SmallVector<int> Mask(NumElts); 685 std::iota(Mask.begin(), Mask.end(), 0); 686 Mask[Index] = Index + NumElts; 687 InstructionCost OldCost = 688 TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) + 689 TTI.getVectorInstrCost(I, VecTy, CostKind, Index); 690 691 // If the extract has one use, it will be eliminated, so count it in the 692 // original cost. If it has more than one use, ignore the cost because it will 693 // be the same before/after. 694 if (Extract->hasOneUse()) 695 OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index); 696 697 InstructionCost NewCost = 698 TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) + 699 TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, CostKind); 700 701 bool NeedLenChg = SrcVecTy->getNumElements() != NumElts; 702 // If the lengths of the two vectors are not equal, 703 // we need to add a length-change vector. Add this cost. 
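  // Illustrative example (assumed types): for a <4 x float> DestVec with
  // Index == 2, Mask above is {0, 1, 6, 3}, i.e. lane 2 is taken from the
  // negated vector. If SrcVec is instead <8 x float>, a length-change shuffle
  // with SrcMask {poison, poison, 2, poison} is needed first; its cost is
  // added below.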
704 SmallVector<int> SrcMask; 705 if (NeedLenChg) { 706 SrcMask.assign(NumElts, PoisonMaskElem); 707 SrcMask[Index] = Index; 708 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, 709 SrcVecTy, SrcMask, CostKind); 710 } 711 712 if (NewCost > OldCost) 713 return false; 714 715 Value *NewShuf; 716 // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index 717 Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg); 718 if (NeedLenChg) { 719 // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask 720 Value *LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask); 721 NewShuf = Builder.CreateShuffleVector(DestVec, LenChgShuf, Mask); 722 } else { 723 // shuffle DestVec, (fneg SrcVec), Mask 724 NewShuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask); 725 } 726 727 replaceValue(I, *NewShuf); 728 return true; 729 } 730 731 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the 732 /// destination type followed by shuffle. This can enable further transforms by 733 /// moving bitcasts or shuffles together. 734 bool VectorCombine::foldBitcastShuffle(Instruction &I) { 735 Value *V0, *V1; 736 ArrayRef<int> Mask; 737 if (!match(&I, m_BitCast(m_OneUse( 738 m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask)))))) 739 return false; 740 741 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for 742 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle 743 // mask for scalable type is a splat or not. 744 // 2) Disallow non-vector casts. 745 // TODO: We could allow any shuffle. 746 auto *DestTy = dyn_cast<FixedVectorType>(I.getType()); 747 auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType()); 748 if (!DestTy || !SrcTy) 749 return false; 750 751 unsigned DestEltSize = DestTy->getScalarSizeInBits(); 752 unsigned SrcEltSize = SrcTy->getScalarSizeInBits(); 753 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0) 754 return false; 755 756 bool IsUnary = isa<UndefValue>(V1); 757 758 // For binary shuffles, only fold bitcast(shuffle(X,Y)) 759 // if it won't increase the number of bitcasts. 760 if (!IsUnary) { 761 auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType()); 762 auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType()); 763 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) && 764 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType())) 765 return false; 766 } 767 768 SmallVector<int, 16> NewMask; 769 if (DestEltSize <= SrcEltSize) { 770 // The bitcast is from wide to narrow/equal elements. The shuffle mask can 771 // always be expanded to the equivalent form choosing narrower elements. 772 assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask"); 773 unsigned ScaleFactor = SrcEltSize / DestEltSize; 774 narrowShuffleMaskElts(ScaleFactor, Mask, NewMask); 775 } else { 776 // The bitcast is from narrow elements to wide elements. The shuffle mask 777 // must choose consecutive elements to allow casting first. 778 assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask"); 779 unsigned ScaleFactor = DestEltSize / SrcEltSize; 780 if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask)) 781 return false; 782 } 783 784 // Bitcast the shuffle src - keep its original width but using the destination 785 // scalar type. 
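  // Illustrative example (assumed): bitcasting the result of a <4 x i32>
  // shuffle with mask {1, 0, 3, 2} to <8 x i16> uses ScaleFactor == 2, so
  // NewMask is {2, 3, 0, 1, 6, 7, 4, 5} and the source is bitcast to
  // <8 x i16> before shuffling.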
786 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize; 787 auto *NewShuffleTy = 788 FixedVectorType::get(DestTy->getScalarType(), NumSrcElts); 789 auto *OldShuffleTy = 790 FixedVectorType::get(SrcTy->getScalarType(), Mask.size()); 791 unsigned NumOps = IsUnary ? 1 : 2; 792 793 // The new shuffle must not cost more than the old shuffle. 794 TargetTransformInfo::ShuffleKind SK = 795 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc 796 : TargetTransformInfo::SK_PermuteTwoSrc; 797 798 InstructionCost NewCost = 799 TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CostKind) + 800 (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy, 801 TargetTransformInfo::CastContextHint::None, 802 CostKind)); 803 InstructionCost OldCost = 804 TTI.getShuffleCost(SK, SrcTy, Mask, CostKind) + 805 TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy, 806 TargetTransformInfo::CastContextHint::None, 807 CostKind); 808 809 LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: " 810 << OldCost << " vs NewCost: " << NewCost << "\n"); 811 812 if (NewCost > OldCost || !NewCost.isValid()) 813 return false; 814 815 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC' 816 ++NumShufOfBitcast; 817 Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy); 818 Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy); 819 Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask); 820 replaceValue(I, *Shuf); 821 return true; 822 } 823 824 /// VP Intrinsics whose vector operands are both splat values may be simplified 825 /// into the scalar version of the operation and the result splatted. This 826 /// can lead to scalarization down the line. 827 bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) { 828 if (!isa<VPIntrinsic>(I)) 829 return false; 830 VPIntrinsic &VPI = cast<VPIntrinsic>(I); 831 Value *Op0 = VPI.getArgOperand(0); 832 Value *Op1 = VPI.getArgOperand(1); 833 834 if (!isSplatValue(Op0) || !isSplatValue(Op1)) 835 return false; 836 837 // Check getSplatValue early in this function, to avoid doing unnecessary 838 // work. 839 Value *ScalarOp0 = getSplatValue(Op0); 840 Value *ScalarOp1 = getSplatValue(Op1); 841 if (!ScalarOp0 || !ScalarOp1) 842 return false; 843 844 // For the binary VP intrinsics supported here, the result on disabled lanes 845 // is a poison value. For now, only do this simplification if all lanes 846 // are active. 847 // TODO: Relax the condition that all lanes are active by using insertelement 848 // on inactive lanes. 
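  // Illustrative example (assumed IR, not from a test):
  //   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %xs, <4 x i32> %ys,
  //                                          <4 x i1> %alltrue, i32 %evl)
  // where %xs and %ys are splats of %x and %y and %alltrue is an all-ones
  // mask, can be rewritten as a scalar 'add i32 %x, %y' followed by a splat
  // of the result.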
849 auto IsAllTrueMask = [](Value *MaskVal) { 850 if (Value *SplattedVal = getSplatValue(MaskVal)) 851 if (auto *ConstValue = dyn_cast<Constant>(SplattedVal)) 852 return ConstValue->isAllOnesValue(); 853 return false; 854 }; 855 if (!IsAllTrueMask(VPI.getArgOperand(2))) 856 return false; 857 858 // Check to make sure we support scalarization of the intrinsic 859 Intrinsic::ID IntrID = VPI.getIntrinsicID(); 860 if (!VPBinOpIntrinsic::isVPBinOp(IntrID)) 861 return false; 862 863 // Calculate cost of splatting both operands into vectors and the vector 864 // intrinsic 865 VectorType *VecTy = cast<VectorType>(VPI.getType()); 866 SmallVector<int> Mask; 867 if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy)) 868 Mask.resize(FVTy->getNumElements(), 0); 869 InstructionCost SplatCost = 870 TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) + 871 TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, Mask, 872 CostKind); 873 874 // Calculate the cost of the VP Intrinsic 875 SmallVector<Type *, 4> Args; 876 for (Value *V : VPI.args()) 877 Args.push_back(V->getType()); 878 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args); 879 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind); 880 InstructionCost OldCost = 2 * SplatCost + VectorOpCost; 881 882 // Determine scalar opcode 883 std::optional<unsigned> FunctionalOpcode = 884 VPI.getFunctionalOpcode(); 885 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt; 886 if (!FunctionalOpcode) { 887 ScalarIntrID = VPI.getFunctionalIntrinsicID(); 888 if (!ScalarIntrID) 889 return false; 890 } 891 892 // Calculate cost of scalarizing 893 InstructionCost ScalarOpCost = 0; 894 if (ScalarIntrID) { 895 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args); 896 ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind); 897 } else { 898 ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode, 899 VecTy->getScalarType(), CostKind); 900 } 901 902 // The existing splats may be kept around if other instructions use them. 903 InstructionCost CostToKeepSplats = 904 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse()); 905 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats; 906 907 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI 908 << "\n"); 909 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost 910 << ", Cost of scalarizing:" << NewCost << "\n"); 911 912 // We want to scalarize unless the vector variant actually has lower cost. 913 if (OldCost < NewCost || !NewCost.isValid()) 914 return false; 915 916 // Scalarize the intrinsic 917 ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount(); 918 Value *EVL = VPI.getArgOperand(3); 919 920 // If the VP op might introduce UB or poison, we can scalarize it provided 921 // that we know the EVL > 0: If the EVL is zero, then the original VP op 922 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by 923 // scalarizing it. 924 bool SafeToSpeculate; 925 if (ScalarIntrID) 926 SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID) 927 .hasFnAttr(Attribute::AttrKind::Speculatable); 928 else 929 SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode( 930 *FunctionalOpcode, &VPI, nullptr, &AC, &DT); 931 if (!SafeToSpeculate && 932 !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI))) 933 return false; 934 935 Value *ScalarVal = 936 ScalarIntrID 937 ? 
Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID, 938 {ScalarOp0, ScalarOp1}) 939 : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode), 940 ScalarOp0, ScalarOp1); 941 942 replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal)); 943 return true; 944 } 945 946 /// Match a vector binop or compare instruction with at least one inserted 947 /// scalar operand and convert to scalar binop/cmp followed by insertelement. 948 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { 949 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE; 950 Value *Ins0, *Ins1; 951 if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) && 952 !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) 953 return false; 954 955 // Do not convert the vector condition of a vector select into a scalar 956 // condition. That may cause problems for codegen because of differences in 957 // boolean formats and register-file transfers. 958 // TODO: Can we account for that in the cost model? 959 bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE; 960 if (IsCmp) 961 for (User *U : I.users()) 962 if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value()))) 963 return false; 964 965 // Match against one or both scalar values being inserted into constant 966 // vectors: 967 // vec_op VecC0, (inselt VecC1, V1, Index) 968 // vec_op (inselt VecC0, V0, Index), VecC1 969 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) 970 // TODO: Deal with mismatched index constants and variable indexes? 971 Constant *VecC0 = nullptr, *VecC1 = nullptr; 972 Value *V0 = nullptr, *V1 = nullptr; 973 uint64_t Index0 = 0, Index1 = 0; 974 if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0), 975 m_ConstantInt(Index0))) && 976 !match(Ins0, m_Constant(VecC0))) 977 return false; 978 if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1), 979 m_ConstantInt(Index1))) && 980 !match(Ins1, m_Constant(VecC1))) 981 return false; 982 983 bool IsConst0 = !V0; 984 bool IsConst1 = !V1; 985 if (IsConst0 && IsConst1) 986 return false; 987 if (!IsConst0 && !IsConst1 && Index0 != Index1) 988 return false; 989 990 auto *VecTy0 = cast<VectorType>(Ins0->getType()); 991 auto *VecTy1 = cast<VectorType>(Ins1->getType()); 992 if (VecTy0->getElementCount().getKnownMinValue() <= Index0 || 993 VecTy1->getElementCount().getKnownMinValue() <= Index1) 994 return false; 995 996 // Bail for single insertion if it is a load. 997 // TODO: Handle this once getVectorInstrCost can cost for load/stores. 998 auto *I0 = dyn_cast_or_null<Instruction>(V0); 999 auto *I1 = dyn_cast_or_null<Instruction>(V1); 1000 if ((IsConst0 && I1 && I1->mayReadFromMemory()) || 1001 (IsConst1 && I0 && I0->mayReadFromMemory())) 1002 return false; 1003 1004 uint64_t Index = IsConst0 ? Index1 : Index0; 1005 Type *ScalarTy = IsConst0 ? 
V1->getType() : V0->getType(); 1006 Type *VecTy = I.getType(); 1007 assert(VecTy->isVectorTy() && 1008 (IsConst0 || IsConst1 || V0->getType() == V1->getType()) && 1009 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() || 1010 ScalarTy->isPointerTy()) && 1011 "Unexpected types for insert element into binop or cmp"); 1012 1013 unsigned Opcode = I.getOpcode(); 1014 InstructionCost ScalarOpCost, VectorOpCost; 1015 if (IsCmp) { 1016 CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate(); 1017 ScalarOpCost = TTI.getCmpSelInstrCost( 1018 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind); 1019 VectorOpCost = TTI.getCmpSelInstrCost( 1020 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind); 1021 } else { 1022 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind); 1023 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind); 1024 } 1025 1026 // Get cost estimate for the insert element. This cost will factor into 1027 // both sequences. 1028 InstructionCost InsertCost = TTI.getVectorInstrCost( 1029 Instruction::InsertElement, VecTy, CostKind, Index); 1030 InstructionCost OldCost = 1031 (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost; 1032 InstructionCost NewCost = ScalarOpCost + InsertCost + 1033 (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) + 1034 (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost); 1035 1036 // We want to scalarize unless the vector variant actually has lower cost. 1037 if (OldCost < NewCost || !NewCost.isValid()) 1038 return false; 1039 1040 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) --> 1041 // inselt NewVecC, (scalar_op V0, V1), Index 1042 if (IsCmp) 1043 ++NumScalarCmp; 1044 else 1045 ++NumScalarBO; 1046 1047 // For constant cases, extract the scalar element, this should constant fold. 1048 if (IsConst0) 1049 V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index)); 1050 if (IsConst1) 1051 V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index)); 1052 1053 Value *Scalar = 1054 IsCmp ? Builder.CreateCmp(Pred, V0, V1) 1055 : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1); 1056 1057 Scalar->setName(I.getName() + ".scalar"); 1058 1059 // All IR flags are safe to back-propagate. There is no potential for extra 1060 // poison to be created by the scalar instruction. 1061 if (auto *ScalarInst = dyn_cast<Instruction>(Scalar)) 1062 ScalarInst->copyIRFlags(&I); 1063 1064 // Fold the vector constants in the original vectors into a new base vector. 1065 Value *NewVecC = 1066 IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1) 1067 : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1); 1068 Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index); 1069 replaceValue(I, *Insert); 1070 return true; 1071 } 1072 1073 /// Try to combine a scalar binop + 2 scalar compares of extracted elements of 1074 /// a vector into vector operations followed by extract. Note: The SLP pass 1075 /// may miss this pattern because of implementation problems. 1076 bool VectorCombine::foldExtractedCmps(Instruction &I) { 1077 auto *BI = dyn_cast<BinaryOperator>(&I); 1078 1079 // We are looking for a scalar binop of booleans. 1080 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1) 1081 if (!BI || !I.getType()->isIntegerTy(1)) 1082 return false; 1083 1084 // The compare predicates should match, and each compare should have a 1085 // constant operand. 
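  // Illustrative example of the scalar pattern (assumed IR):
  //   %e0 = extractelement <4 x i32> %x, i32 0
  //   %e1 = extractelement <4 x i32> %x, i32 3
  //   %c0 = icmp sgt i32 %e0, 42
  //   %c1 = icmp sgt i32 %e1, 7
  //   %r  = and i1 %c0, %c1
  // This can become a vector icmp against a mostly-poison constant vector, a
  // shuffle that lines up lane 3 with lane 0, a vector 'and', and a single
  // extract (see the worked constant vector further below).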
1086 Value *B0 = I.getOperand(0), *B1 = I.getOperand(1); 1087 Instruction *I0, *I1; 1088 Constant *C0, *C1; 1089 CmpPredicate P0, P1; 1090 // FIXME: Use CmpPredicate::getMatching here. 1091 if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) || 1092 !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) || 1093 P0 != static_cast<CmpInst::Predicate>(P1)) 1094 return false; 1095 1096 // The compare operands must be extracts of the same vector with constant 1097 // extract indexes. 1098 Value *X; 1099 uint64_t Index0, Index1; 1100 if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) || 1101 !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))) 1102 return false; 1103 1104 auto *Ext0 = cast<ExtractElementInst>(I0); 1105 auto *Ext1 = cast<ExtractElementInst>(I1); 1106 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, CostKind); 1107 if (!ConvertToShuf) 1108 return false; 1109 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) && 1110 "Unknown ExtractElementInst"); 1111 1112 // The original scalar pattern is: 1113 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1) 1114 CmpInst::Predicate Pred = P0; 1115 unsigned CmpOpcode = 1116 CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp; 1117 auto *VecTy = dyn_cast<FixedVectorType>(X->getType()); 1118 if (!VecTy) 1119 return false; 1120 1121 InstructionCost Ext0Cost = 1122 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0); 1123 InstructionCost Ext1Cost = 1124 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1); 1125 InstructionCost CmpCost = TTI.getCmpSelInstrCost( 1126 CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred, 1127 CostKind); 1128 1129 InstructionCost OldCost = 1130 Ext0Cost + Ext1Cost + CmpCost * 2 + 1131 TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind); 1132 1133 // The proposed vector pattern is: 1134 // vcmp = cmp Pred X, VecC 1135 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0 1136 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0; 1137 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1; 1138 auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType())); 1139 InstructionCost NewCost = TTI.getCmpSelInstrCost( 1140 CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred, 1141 CostKind); 1142 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem); 1143 ShufMask[CheapIndex] = ExpensiveIndex; 1144 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy, 1145 ShufMask, CostKind); 1146 NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind); 1147 NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex); 1148 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost; 1149 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost; 1150 1151 // Aggressively form vector ops if the cost is equal because the transform 1152 // may enable further optimization. 1153 // Codegen can reverse this transform (scalarize) if it was not profitable. 1154 if (OldCost < NewCost || !NewCost.isValid()) 1155 return false; 1156 1157 // Create a vector constant from the 2 scalar constants. 
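  // E.g. (continuing the illustrative example above, assumed values): with
  // Index0 == 0, C0 == 42, Index1 == 3, C1 == 7 on <4 x i32>, the constant
  // vector is <i32 42, i32 poison, i32 poison, i32 7>.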
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   PoisonValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
  Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
  Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

// Check if the memory location is modified between two instructions in the
// same basic block.
static bool isMemModifiedBetween(BasicBlock::iterator Begin,
                                 BasicBlock::iterator End,
                                 const MemoryLocation &Loc, AAResults &AA) {
  unsigned NumScanned = 0;
  return std::any_of(Begin, End, [&](const Instruction &Instr) {
    return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
           ++NumScanned > MaxInstrsToScan;
  });
}

namespace {
/// Helper class to indicate whether a vector index can be safely scalarized
/// and if a freeze needs to be inserted.
class ScalarizationResult {
  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };

  StatusTy Status;
  Value *ToFreeze;

  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
      : Status(Status), ToFreeze(ToFreeze) {}

public:
  ScalarizationResult(const ScalarizationResult &Other) = default;
  ~ScalarizationResult() {
    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
  }

  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
  static ScalarizationResult safe() { return {StatusTy::Safe}; }
  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
    return {StatusTy::SafeWithFreeze, ToFreeze};
  }

  /// Returns true if the index can be scalarized without requiring a freeze.
  bool isSafe() const { return Status == StatusTy::Safe; }
  /// Returns true if the index cannot be scalarized.
  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  /// Returns true if the index can be scalarized, but requires inserting a
  /// freeze.
  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }

  /// Reset the state to Unsafe and clear ToFreeze if set.
  void discard() {
    ToFreeze = nullptr;
    Status = StatusTy::Unsafe;
  }

  /// Freeze \p ToFreeze and update the use in \p UserI to use it.
  void freeze(IRBuilder<> &Builder, Instruction &UserI) {
    assert(isSafeWithFreeze() &&
           "should only be used when freezing is required");
    assert(is_contained(ToFreeze->users(), &UserI) &&
           "UserI must be a user of ToFreeze");
    IRBuilder<>::InsertPointGuard Guard(Builder);
    Builder.SetInsertPoint(cast<Instruction>(&UserI));
    Value *Frozen =
        Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
    for (Use &U : make_early_inc_range((UserI.operands())))
      if (U.get() == ToFreeze)
        U.set(Frozen);

    ToFreeze = nullptr;
  }
};
} // namespace

/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
/// Idx. \p Idx must access a valid vector element.
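/// Illustrative examples (assumed): a constant index 3 into a <4 x i32>
/// vector is safe; an index of the form 'and i64 %i, 3' is also safe but may
/// require freezing %i first; an index whose range cannot be bounded below
/// the element count is unsafe.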
1243 static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx, 1244 Instruction *CtxI, 1245 AssumptionCache &AC, 1246 const DominatorTree &DT) { 1247 // We do checks for both fixed vector types and scalable vector types. 1248 // This is the number of elements of fixed vector types, 1249 // or the minimum number of elements of scalable vector types. 1250 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue(); 1251 1252 if (auto *C = dyn_cast<ConstantInt>(Idx)) { 1253 if (C->getValue().ult(NumElements)) 1254 return ScalarizationResult::safe(); 1255 return ScalarizationResult::unsafe(); 1256 } 1257 1258 unsigned IntWidth = Idx->getType()->getScalarSizeInBits(); 1259 APInt Zero(IntWidth, 0); 1260 APInt MaxElts(IntWidth, NumElements); 1261 ConstantRange ValidIndices(Zero, MaxElts); 1262 ConstantRange IdxRange(IntWidth, true); 1263 1264 if (isGuaranteedNotToBePoison(Idx, &AC)) { 1265 if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false, 1266 true, &AC, CtxI, &DT))) 1267 return ScalarizationResult::safe(); 1268 return ScalarizationResult::unsafe(); 1269 } 1270 1271 // If the index may be poison, check if we can insert a freeze before the 1272 // range of the index is restricted. 1273 Value *IdxBase; 1274 ConstantInt *CI; 1275 if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) { 1276 IdxRange = IdxRange.binaryAnd(CI->getValue()); 1277 } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) { 1278 IdxRange = IdxRange.urem(CI->getValue()); 1279 } 1280 1281 if (ValidIndices.contains(IdxRange)) 1282 return ScalarizationResult::safeWithFreeze(IdxBase); 1283 return ScalarizationResult::unsafe(); 1284 } 1285 1286 /// The memory operation on a vector of \p ScalarType had alignment of 1287 /// \p VectorAlignment. Compute the maximal, but conservatively correct, 1288 /// alignment that will be valid for the memory operation on a single scalar 1289 /// element of the same type with index \p Idx. 1290 static Align computeAlignmentAfterScalarization(Align VectorAlignment, 1291 Type *ScalarType, Value *Idx, 1292 const DataLayout &DL) { 1293 if (auto *C = dyn_cast<ConstantInt>(Idx)) 1294 return commonAlignment(VectorAlignment, 1295 C->getZExtValue() * DL.getTypeStoreSize(ScalarType)); 1296 return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType)); 1297 } 1298 1299 // Combine patterns like: 1300 // %0 = load <4 x i32>, <4 x i32>* %a 1301 // %1 = insertelement <4 x i32> %0, i32 %b, i32 1 1302 // store <4 x i32> %1, <4 x i32>* %a 1303 // to: 1304 // %0 = bitcast <4 x i32>* %a to i32* 1305 // %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1 1306 // store i32 %b, i32* %1 1307 bool VectorCombine::foldSingleElementStore(Instruction &I) { 1308 auto *SI = cast<StoreInst>(&I); 1309 if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType())) 1310 return false; 1311 1312 // TODO: Combine more complicated patterns (multiple insert) by referencing 1313 // TargetTransformInfo. 1314 Instruction *Source; 1315 Value *NewElement; 1316 Value *Idx; 1317 if (!match(SI->getValueOperand(), 1318 m_InsertElt(m_Instruction(Source), m_Value(NewElement), 1319 m_Value(Idx)))) 1320 return false; 1321 1322 if (auto *Load = dyn_cast<LoadInst>(Source)) { 1323 auto VecTy = cast<VectorType>(SI->getValueOperand()->getType()); 1324 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts(); 1325 // Don't optimize for atomic/volatile load or store. 
Ensure memory is not 1326 // modified between, vector type matches store size, and index is inbounds. 1327 if (!Load->isSimple() || Load->getParent() != SI->getParent() || 1328 !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) || 1329 SrcAddr != SI->getPointerOperand()->stripPointerCasts()) 1330 return false; 1331 1332 auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT); 1333 if (ScalarizableIdx.isUnsafe() || 1334 isMemModifiedBetween(Load->getIterator(), SI->getIterator(), 1335 MemoryLocation::get(SI), AA)) 1336 return false; 1337 1338 if (ScalarizableIdx.isSafeWithFreeze()) 1339 ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx)); 1340 Value *GEP = Builder.CreateInBoundsGEP( 1341 SI->getValueOperand()->getType(), SI->getPointerOperand(), 1342 {ConstantInt::get(Idx->getType(), 0), Idx}); 1343 StoreInst *NSI = Builder.CreateStore(NewElement, GEP); 1344 NSI->copyMetadata(*SI); 1345 Align ScalarOpAlignment = computeAlignmentAfterScalarization( 1346 std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx, 1347 *DL); 1348 NSI->setAlignment(ScalarOpAlignment); 1349 replaceValue(I, *NSI); 1350 eraseInstruction(I); 1351 return true; 1352 } 1353 1354 return false; 1355 } 1356 1357 /// Try to scalarize vector loads feeding extractelement instructions. 1358 bool VectorCombine::scalarizeLoadExtract(Instruction &I) { 1359 Value *Ptr; 1360 if (!match(&I, m_Load(m_Value(Ptr)))) 1361 return false; 1362 1363 auto *VecTy = cast<VectorType>(I.getType()); 1364 auto *LI = cast<LoadInst>(&I); 1365 if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType())) 1366 return false; 1367 1368 InstructionCost OriginalCost = 1369 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(), 1370 LI->getPointerAddressSpace(), CostKind); 1371 InstructionCost ScalarizedCost = 0; 1372 1373 Instruction *LastCheckedInst = LI; 1374 unsigned NumInstChecked = 0; 1375 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze; 1376 auto FailureGuard = make_scope_exit([&]() { 1377 // If the transform is aborted, discard the ScalarizationResults. 1378 for (auto &Pair : NeedFreeze) 1379 Pair.second.discard(); 1380 }); 1381 1382 // Check if all users of the load are extracts with no memory modifications 1383 // between the load and the extract. Compute the cost of both the original 1384 // code and the scalarized version. 1385 for (User *U : LI->users()) { 1386 auto *UI = dyn_cast<ExtractElementInst>(U); 1387 if (!UI || UI->getParent() != LI->getParent()) 1388 return false; 1389 1390 // Check if any instruction between the load and the extract may modify 1391 // memory. 1392 if (LastCheckedInst->comesBefore(UI)) { 1393 for (Instruction &I : 1394 make_range(std::next(LI->getIterator()), UI->getIterator())) { 1395 // Bail out if we reached the check limit or the instruction may write 1396 // to memory. 1397 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory()) 1398 return false; 1399 NumInstChecked++; 1400 } 1401 LastCheckedInst = UI; 1402 } 1403 1404 auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT); 1405 if (ScalarIdx.isUnsafe()) 1406 return false; 1407 if (ScalarIdx.isSafeWithFreeze()) { 1408 NeedFreeze.try_emplace(UI, ScalarIdx); 1409 ScalarIdx.discard(); 1410 } 1411 1412 auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1)); 1413 OriginalCost += 1414 TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, 1415 Index ? 
Index->getZExtValue() : -1); 1416 ScalarizedCost += 1417 TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(), 1418 Align(1), LI->getPointerAddressSpace(), CostKind); 1419 ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType()); 1420 } 1421 1422 if (ScalarizedCost >= OriginalCost) 1423 return false; 1424 1425 // Replace extracts with narrow scalar loads. 1426 for (User *U : LI->users()) { 1427 auto *EI = cast<ExtractElementInst>(U); 1428 Value *Idx = EI->getOperand(1); 1429 1430 // Insert 'freeze' for poison indexes. 1431 auto It = NeedFreeze.find(EI); 1432 if (It != NeedFreeze.end()) 1433 It->second.freeze(Builder, *cast<Instruction>(Idx)); 1434 1435 Builder.SetInsertPoint(EI); 1436 Value *GEP = 1437 Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx}); 1438 auto *NewLoad = cast<LoadInst>(Builder.CreateLoad( 1439 VecTy->getElementType(), GEP, EI->getName() + ".scalar")); 1440 1441 Align ScalarOpAlignment = computeAlignmentAfterScalarization( 1442 LI->getAlign(), VecTy->getElementType(), Idx, *DL); 1443 NewLoad->setAlignment(ScalarOpAlignment); 1444 1445 replaceValue(*EI, *NewLoad); 1446 } 1447 1448 FailureGuard.release(); 1449 return true; 1450 } 1451 1452 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" 1453 /// to "(bitcast (concat X, Y))" 1454 /// where X/Y are bitcasted from i1 mask vectors. 1455 bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) { 1456 Type *Ty = I.getType(); 1457 if (!Ty->isIntegerTy()) 1458 return false; 1459 1460 // TODO: Add big endian test coverage 1461 if (DL->isBigEndian()) 1462 return false; 1463 1464 // Restrict to disjoint cases so the mask vectors aren't overlapping. 1465 Instruction *X, *Y; 1466 if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y)))) 1467 return false; 1468 1469 // Allow both sources to contain shl, to handle more generic pattern: 1470 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))" 1471 Value *SrcX; 1472 uint64_t ShAmtX = 0; 1473 if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) && 1474 !match(X, m_OneUse( 1475 m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))), 1476 m_ConstantInt(ShAmtX))))) 1477 return false; 1478 1479 Value *SrcY; 1480 uint64_t ShAmtY = 0; 1481 if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) && 1482 !match(Y, m_OneUse( 1483 m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))), 1484 m_ConstantInt(ShAmtY))))) 1485 return false; 1486 1487 // Canonicalize larger shift to the RHS. 1488 if (ShAmtX > ShAmtY) { 1489 std::swap(X, Y); 1490 std::swap(SrcX, SrcY); 1491 std::swap(ShAmtX, ShAmtY); 1492 } 1493 1494 // Ensure both sources are matching vXi1 bool mask types, and that the shift 1495 // difference is the mask width so they can be easily concatenated together. 
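  // A hypothetical little-endian illustration with <4 x i1> masks feeding an
  // i8 result (value names are illustrative only):
  //   %zx = zext i4 (bitcast <4 x i1> %x to i4) to i8
  //   %zy = zext i4 (bitcast <4 x i1> %y to i4) to i8
  //   %r  = or disjoint i8 %zx, (shl i8 %zy, 4)
  // is expected to become
  //   %c  = shufflevector <4 x i1> %x, <4 x i1> %y, <8 x i32> <i32 0, ..., i32 7>
  //   %r  = bitcast <8 x i1> %c to i8
  // Here the shift difference (4 - 0) equals the mask element count, so X and
  // Y occupy adjacent, non-overlapping bit ranges.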
1496 uint64_t ShAmtDiff = ShAmtY - ShAmtX; 1497 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0); 1498 unsigned BitWidth = Ty->getPrimitiveSizeInBits(); 1499 auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType()); 1500 if (!MaskTy || SrcX->getType() != SrcY->getType() || 1501 !MaskTy->getElementType()->isIntegerTy(1) || 1502 MaskTy->getNumElements() != ShAmtDiff || 1503 MaskTy->getNumElements() > (BitWidth / 2)) 1504 return false; 1505 1506 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy); 1507 auto *ConcatIntTy = 1508 Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements()); 1509 auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff); 1510 1511 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements()); 1512 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 1513 1514 // TODO: Is it worth supporting multi use cases? 1515 InstructionCost OldCost = 0; 1516 OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind); 1517 OldCost += 1518 NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind); 1519 OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy, 1520 TTI::CastContextHint::None, CostKind); 1521 OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy, 1522 TTI::CastContextHint::None, CostKind); 1523 1524 InstructionCost NewCost = 0; 1525 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy, 1526 ConcatMask, CostKind); 1527 NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy, 1528 TTI::CastContextHint::None, CostKind); 1529 if (Ty != ConcatIntTy) 1530 NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy, 1531 TTI::CastContextHint::None, CostKind); 1532 if (ShAmtX > 0) 1533 NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind); 1534 1535 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I 1536 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1537 << "\n"); 1538 1539 if (NewCost > OldCost) 1540 return false; 1541 1542 // Build bool mask concatenation, bitcast back to scalar integer, and perform 1543 // any residual zero-extension or shifting. 1544 Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask); 1545 Worklist.pushValue(Concat); 1546 1547 Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy); 1548 1549 if (Ty != ConcatIntTy) { 1550 Worklist.pushValue(Result); 1551 Result = Builder.CreateZExt(Result, Ty); 1552 } 1553 1554 if (ShAmtX > 0) { 1555 Worklist.pushValue(Result); 1556 Result = Builder.CreateShl(Result, ShAmtX); 1557 } 1558 1559 replaceValue(I, *Result); 1560 return true; 1561 } 1562 1563 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef" 1564 /// --> "binop (shuffle), (shuffle)". 1565 bool VectorCombine::foldPermuteOfBinops(Instruction &I) { 1566 BinaryOperator *BinOp; 1567 ArrayRef<int> OuterMask; 1568 if (!match(&I, 1569 m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask)))) 1570 return false; 1571 1572 // Don't introduce poison into div/rem. 
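  // A poison element in the outer mask would become a poison lane in the
  // merged operand shuffles, and a poison divisor lane in a div/rem may act
  // like a division by zero. Illustrative sketch (names made up):
  //   %d = udiv <4 x i32> %a, %b
  //   %p = shufflevector <4 x i32> %d, <4 x i32> poison,
  //                      <4 x i32> <i32 0, i32 poison, i32 2, i32 3>
  // Folding the poison lane into %b's shuffle could poison lane 1's divisor.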
1573 if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem)) 1574 return false; 1575 1576 Value *Op00, *Op01; 1577 ArrayRef<int> Mask0; 1578 if (!match(BinOp->getOperand(0), 1579 m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0))))) 1580 return false; 1581 1582 Value *Op10, *Op11; 1583 ArrayRef<int> Mask1; 1584 if (!match(BinOp->getOperand(1), 1585 m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1))))) 1586 return false; 1587 1588 Instruction::BinaryOps Opcode = BinOp->getOpcode(); 1589 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1590 auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType()); 1591 auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType()); 1592 auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType()); 1593 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty) 1594 return false; 1595 1596 unsigned NumSrcElts = BinOpTy->getNumElements(); 1597 1598 // Don't accept shuffles that reference the second operand in 1599 // div/rem or if its an undef arg. 1600 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) && 1601 any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; })) 1602 return false; 1603 1604 // Merge outer / inner shuffles. 1605 SmallVector<int> NewMask0, NewMask1; 1606 for (int M : OuterMask) { 1607 if (M < 0 || M >= (int)NumSrcElts) { 1608 NewMask0.push_back(PoisonMaskElem); 1609 NewMask1.push_back(PoisonMaskElem); 1610 } else { 1611 NewMask0.push_back(Mask0[M]); 1612 NewMask1.push_back(Mask1[M]); 1613 } 1614 } 1615 1616 // Try to merge shuffles across the binop if the new shuffles are not costly. 1617 InstructionCost OldCost = 1618 TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) + 1619 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy, 1620 OuterMask, CostKind, 0, nullptr, {BinOp}, &I) + 1621 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0, 1622 CostKind, 0, nullptr, {Op00, Op01}, 1623 cast<Instruction>(BinOp->getOperand(0))) + 1624 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1, 1625 CostKind, 0, nullptr, {Op10, Op11}, 1626 cast<Instruction>(BinOp->getOperand(1))); 1627 1628 InstructionCost NewCost = 1629 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0, 1630 CostKind, 0, nullptr, {Op00, Op01}) + 1631 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1, 1632 CostKind, 0, nullptr, {Op10, Op11}) + 1633 TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind); 1634 1635 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I 1636 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1637 << "\n"); 1638 1639 // If costs are equal, still fold as we reduce instruction count. 1640 if (NewCost > OldCost) 1641 return false; 1642 1643 Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0); 1644 Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1); 1645 Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1); 1646 1647 // Intersect flags from the old binops. 1648 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) 1649 NewInst->copyIRFlags(BinOp); 1650 1651 Worklist.pushValue(Shuf0); 1652 Worklist.pushValue(Shuf1); 1653 replaceValue(I, *NewBO); 1654 return true; 1655 } 1656 1657 /// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)". 1658 /// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)". 
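/// A hypothetical illustration (all value names are made up):
///   %a = add <4 x i32> %x, %y
///   %b = add <4 x i32> %z, %w
///   %r = shufflevector <4 x i32> %a, <4 x i32> %b,
///                      <4 x i32> <i32 0, i32 4, i32 1, i32 5>
/// is expected to become
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> %z,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> %w,
///                       <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %r  = add <4 x i32> %s0, %s1
/// when the cost model considers the rewritten form cheaper (or no more
/// expensive, if one of the new shuffles will constant fold away).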
1659 bool VectorCombine::foldShuffleOfBinops(Instruction &I) { 1660 ArrayRef<int> OldMask; 1661 Instruction *LHS, *RHS; 1662 if (!match(&I, m_Shuffle(m_OneUse(m_Instruction(LHS)), 1663 m_OneUse(m_Instruction(RHS)), m_Mask(OldMask)))) 1664 return false; 1665 1666 // TODO: Add support for addlike etc. 1667 if (LHS->getOpcode() != RHS->getOpcode()) 1668 return false; 1669 1670 Value *X, *Y, *Z, *W; 1671 bool IsCommutative = false; 1672 CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE; 1673 CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE; 1674 if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) && 1675 match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) { 1676 auto *BO = cast<BinaryOperator>(LHS); 1677 // Don't introduce poison into div/rem. 1678 if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem()) 1679 return false; 1680 IsCommutative = BinaryOperator::isCommutative(BO->getOpcode()); 1681 } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) && 1682 match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) && 1683 (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) { 1684 IsCommutative = cast<CmpInst>(LHS)->isCommutative(); 1685 } else 1686 return false; 1687 1688 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1689 auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType()); 1690 auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType()); 1691 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType()) 1692 return false; 1693 1694 unsigned NumSrcElts = BinOpTy->getNumElements(); 1695 1696 // If we have something like "add X, Y" and "add Z, X", swap ops to match. 1697 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z)) 1698 std::swap(X, Y); 1699 1700 auto ConvertToUnary = [NumSrcElts](int &M) { 1701 if (M >= (int)NumSrcElts) 1702 M -= NumSrcElts; 1703 }; 1704 1705 SmallVector<int> NewMask0(OldMask); 1706 TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc; 1707 if (X == Z) { 1708 llvm::for_each(NewMask0, ConvertToUnary); 1709 SK0 = TargetTransformInfo::SK_PermuteSingleSrc; 1710 Z = PoisonValue::get(BinOpTy); 1711 } 1712 1713 SmallVector<int> NewMask1(OldMask); 1714 TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc; 1715 if (Y == W) { 1716 llvm::for_each(NewMask1, ConvertToUnary); 1717 SK1 = TargetTransformInfo::SK_PermuteSingleSrc; 1718 W = PoisonValue::get(BinOpTy); 1719 } 1720 1721 // Try to replace a binop with a shuffle if the shuffle is not costly. 
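  // OldCost below is the two original binops/cmps plus the two-source shuffle
  // of their results; NewCost is the two operand shuffles (possibly demoted to
  // single-source shuffles if an operand was shared) plus a single binop/cmp
  // on the destination type.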
1722 InstructionCost OldCost = 1723 TTI.getInstructionCost(LHS, CostKind) + 1724 TTI.getInstructionCost(RHS, CostKind) + 1725 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinResTy, 1726 OldMask, CostKind, 0, nullptr, {LHS, RHS}, &I); 1727 1728 InstructionCost NewCost = 1729 TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) + 1730 TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W}); 1731 1732 if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) { 1733 NewCost += 1734 TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind); 1735 } else { 1736 auto *ShuffleCmpTy = 1737 FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy); 1738 NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, 1739 ShuffleDstTy, PredLHS, CostKind); 1740 } 1741 1742 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I 1743 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1744 << "\n"); 1745 1746 // If either shuffle will constant fold away, then fold for the same cost as 1747 // we will reduce the instruction count. 1748 bool ReducedInstCount = (isa<Constant>(X) && isa<Constant>(Z)) || 1749 (isa<Constant>(Y) && isa<Constant>(W)); 1750 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost)) 1751 return false; 1752 1753 Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0); 1754 Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1); 1755 Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE 1756 ? Builder.CreateBinOp( 1757 cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1) 1758 : Builder.CreateCmp(PredLHS, Shuf0, Shuf1); 1759 1760 // Intersect flags from the old binops. 1761 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) { 1762 NewInst->copyIRFlags(LHS); 1763 NewInst->andIRFlags(RHS); 1764 } 1765 1766 Worklist.pushValue(Shuf0); 1767 Worklist.pushValue(Shuf1); 1768 replaceValue(I, *NewBO); 1769 return true; 1770 } 1771 1772 /// Try to convert "shuffle (castop), (castop)" with a shared castop operand 1773 /// into "castop (shuffle)". 1774 bool VectorCombine::foldShuffleOfCastops(Instruction &I) { 1775 Value *V0, *V1; 1776 ArrayRef<int> OldMask; 1777 if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask)))) 1778 return false; 1779 1780 auto *C0 = dyn_cast<CastInst>(V0); 1781 auto *C1 = dyn_cast<CastInst>(V1); 1782 if (!C0 || !C1) 1783 return false; 1784 1785 Instruction::CastOps Opcode = C0->getOpcode(); 1786 if (C0->getSrcTy() != C1->getSrcTy()) 1787 return false; 1788 1789 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds. 1790 if (Opcode != C1->getOpcode()) { 1791 if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value()))) 1792 Opcode = Instruction::SExt; 1793 else 1794 return false; 1795 } 1796 1797 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1798 auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy()); 1799 auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy()); 1800 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy) 1801 return false; 1802 1803 unsigned NumSrcElts = CastSrcTy->getNumElements(); 1804 unsigned NumDstElts = CastDstTy->getNumElements(); 1805 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) && 1806 "Only bitcasts expected to alter src/dst element counts"); 1807 1808 // Check for bitcasting of unscalable vector types. 1809 // e.g. 
<32 x i40> -> <40 x i32> 1810 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 && 1811 (NumDstElts % NumSrcElts) != 0) 1812 return false; 1813 1814 SmallVector<int, 16> NewMask; 1815 if (NumSrcElts >= NumDstElts) { 1816 // The bitcast is from wide to narrow/equal elements. The shuffle mask can 1817 // always be expanded to the equivalent form choosing narrower elements. 1818 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask"); 1819 unsigned ScaleFactor = NumSrcElts / NumDstElts; 1820 narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask); 1821 } else { 1822 // The bitcast is from narrow elements to wide elements. The shuffle mask 1823 // must choose consecutive elements to allow casting first. 1824 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask"); 1825 unsigned ScaleFactor = NumDstElts / NumSrcElts; 1826 if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask)) 1827 return false; 1828 } 1829 1830 auto *NewShuffleDstTy = 1831 FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size()); 1832 1833 // Try to replace a castop with a shuffle if the shuffle is not costly. 1834 InstructionCost CostC0 = 1835 TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy, 1836 TTI::CastContextHint::None, CostKind); 1837 InstructionCost CostC1 = 1838 TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy, 1839 TTI::CastContextHint::None, CostKind); 1840 InstructionCost OldCost = CostC0 + CostC1; 1841 OldCost += 1842 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy, 1843 OldMask, CostKind, 0, nullptr, {}, &I); 1844 1845 InstructionCost NewCost = TTI.getShuffleCost( 1846 TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind); 1847 NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy, 1848 TTI::CastContextHint::None, CostKind); 1849 if (!C0->hasOneUse()) 1850 NewCost += CostC0; 1851 if (!C1->hasOneUse()) 1852 NewCost += CostC1; 1853 1854 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I 1855 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1856 << "\n"); 1857 if (NewCost > OldCost) 1858 return false; 1859 1860 Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0), 1861 C1->getOperand(0), NewMask); 1862 Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy); 1863 1864 // Intersect flags from the old casts. 1865 if (auto *NewInst = dyn_cast<Instruction>(Cast)) { 1866 NewInst->copyIRFlags(C0); 1867 NewInst->andIRFlags(C1); 1868 } 1869 1870 Worklist.pushValue(Shuf); 1871 replaceValue(I, *Cast); 1872 return true; 1873 } 1874 1875 /// Try to convert any of: 1876 /// "shuffle (shuffle x, y), (shuffle y, x)" 1877 /// "shuffle (shuffle x, undef), (shuffle y, undef)" 1878 /// "shuffle (shuffle x, undef), y" 1879 /// "shuffle x, (shuffle y, undef)" 1880 /// into "shuffle x, y". 1881 bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { 1882 ArrayRef<int> OuterMask; 1883 Value *OuterV0, *OuterV1; 1884 if (!match(&I, 1885 m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask)))) 1886 return false; 1887 1888 ArrayRef<int> InnerMask0, InnerMask1; 1889 Value *X0, *X1, *Y0, *Y1; 1890 bool Match0 = 1891 match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0))); 1892 bool Match1 = 1893 match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1))); 1894 if (!Match0 && !Match1) 1895 return false; 1896 1897 X0 = Match0 ? X0 : OuterV0; 1898 Y0 = Match0 ? Y0 : OuterV0; 1899 X1 = Match1 ? X1 : OuterV1; 1900 Y1 = Match1 ? 
Y1 : OuterV1; 1901 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1902 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType()); 1903 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType()); 1904 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy || 1905 X0->getType() != X1->getType()) 1906 return false; 1907 1908 unsigned NumSrcElts = ShuffleSrcTy->getNumElements(); 1909 unsigned NumImmElts = ShuffleImmTy->getNumElements(); 1910 1911 // Attempt to merge shuffles, matching upto 2 source operands. 1912 // Replace index to a poison arg with PoisonMaskElem. 1913 // Bail if either inner masks reference an undef arg. 1914 SmallVector<int, 16> NewMask(OuterMask); 1915 Value *NewX = nullptr, *NewY = nullptr; 1916 for (int &M : NewMask) { 1917 Value *Src = nullptr; 1918 if (0 <= M && M < (int)NumImmElts) { 1919 Src = OuterV0; 1920 if (Match0) { 1921 M = InnerMask0[M]; 1922 Src = M >= (int)NumSrcElts ? Y0 : X0; 1923 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M; 1924 } 1925 } else if (M >= (int)NumImmElts) { 1926 Src = OuterV1; 1927 M -= NumImmElts; 1928 if (Match1) { 1929 M = InnerMask1[M]; 1930 Src = M >= (int)NumSrcElts ? Y1 : X1; 1931 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M; 1932 } 1933 } 1934 if (Src && M != PoisonMaskElem) { 1935 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index"); 1936 if (isa<UndefValue>(Src)) { 1937 // We've referenced an undef element - if its poison, update the shuffle 1938 // mask, else bail. 1939 if (!isa<PoisonValue>(Src)) 1940 return false; 1941 M = PoisonMaskElem; 1942 continue; 1943 } 1944 if (!NewX || NewX == Src) { 1945 NewX = Src; 1946 continue; 1947 } 1948 if (!NewY || NewY == Src) { 1949 M += NumSrcElts; 1950 NewY = Src; 1951 continue; 1952 } 1953 return false; 1954 } 1955 } 1956 1957 if (!NewX) 1958 return PoisonValue::get(ShuffleDstTy); 1959 if (!NewY) 1960 NewY = PoisonValue::get(ShuffleSrcTy); 1961 1962 // Have we folded to an Identity shuffle? 1963 if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) { 1964 replaceValue(I, *NewX); 1965 return true; 1966 } 1967 1968 // Try to merge the shuffles if the new shuffle is not costly. 1969 InstructionCost InnerCost0 = 0; 1970 if (Match0) 1971 InnerCost0 = TTI.getInstructionCost(cast<Instruction>(OuterV0), CostKind); 1972 1973 InstructionCost InnerCost1 = 0; 1974 if (Match1) 1975 InnerCost1 = TTI.getInstructionCost(cast<Instruction>(OuterV1), CostKind); 1976 1977 InstructionCost OuterCost = TTI.getShuffleCost( 1978 TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind, 1979 0, nullptr, {OuterV0, OuterV1}, &I); 1980 1981 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost; 1982 1983 bool IsUnary = all_of(NewMask, [&](int M) { return M < (int)NumSrcElts; }); 1984 TargetTransformInfo::ShuffleKind SK = 1985 IsUnary ? 
TargetTransformInfo::SK_PermuteSingleSrc 1986 : TargetTransformInfo::SK_PermuteTwoSrc; 1987 InstructionCost NewCost = TTI.getShuffleCost( 1988 SK, ShuffleSrcTy, NewMask, CostKind, 0, nullptr, {NewX, NewY}); 1989 if (!OuterV0->hasOneUse()) 1990 NewCost += InnerCost0; 1991 if (!OuterV1->hasOneUse()) 1992 NewCost += InnerCost1; 1993 1994 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I 1995 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1996 << "\n"); 1997 if (NewCost > OldCost) 1998 return false; 1999 2000 Value *Shuf = Builder.CreateShuffleVector(NewX, NewY, NewMask); 2001 replaceValue(I, *Shuf); 2002 return true; 2003 } 2004 2005 /// Try to convert 2006 /// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)". 2007 bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) { 2008 Value *V0, *V1; 2009 ArrayRef<int> OldMask; 2010 if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)), 2011 m_Mask(OldMask)))) 2012 return false; 2013 2014 auto *II0 = dyn_cast<IntrinsicInst>(V0); 2015 auto *II1 = dyn_cast<IntrinsicInst>(V1); 2016 if (!II0 || !II1) 2017 return false; 2018 2019 Intrinsic::ID IID = II0->getIntrinsicID(); 2020 if (IID != II1->getIntrinsicID()) 2021 return false; 2022 2023 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 2024 auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType()); 2025 if (!ShuffleDstTy || !II0Ty) 2026 return false; 2027 2028 if (!isTriviallyVectorizable(IID)) 2029 return false; 2030 2031 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 2032 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI) && 2033 II0->getArgOperand(I) != II1->getArgOperand(I)) 2034 return false; 2035 2036 InstructionCost OldCost = 2037 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) + 2038 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind) + 2039 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, II0Ty, OldMask, 2040 CostKind, 0, nullptr, {II0, II1}, &I); 2041 2042 SmallVector<Type *> NewArgsTy; 2043 InstructionCost NewCost = 0; 2044 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 2045 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) { 2046 NewArgsTy.push_back(II0->getArgOperand(I)->getType()); 2047 } else { 2048 auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType()); 2049 NewArgsTy.push_back(FixedVectorType::get(VecTy->getElementType(), 2050 VecTy->getNumElements() * 2)); 2051 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, 2052 VecTy, OldMask, CostKind); 2053 } 2054 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy); 2055 NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind); 2056 2057 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I 2058 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 2059 << "\n"); 2060 2061 if (NewCost > OldCost) 2062 return false; 2063 2064 SmallVector<Value *> NewArgs; 2065 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 2066 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) { 2067 NewArgs.push_back(II0->getArgOperand(I)); 2068 } else { 2069 Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I), 2070 II1->getArgOperand(I), OldMask); 2071 NewArgs.push_back(Shuf); 2072 Worklist.pushValue(Shuf); 2073 } 2074 Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs); 2075 2076 // Intersect flags from the old intrinsics. 
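  // e.g. if only one of the original calls was marked nnan, the merged call
  // must drop it; andIRFlags keeps only the flags common to both.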
2077 if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) { 2078 NewInst->copyIRFlags(II0); 2079 NewInst->andIRFlags(II1); 2080 } 2081 2082 replaceValue(I, *NewIntrinsic); 2083 return true; 2084 } 2085 2086 using InstLane = std::pair<Use *, int>; 2087 2088 static InstLane lookThroughShuffles(Use *U, int Lane) { 2089 while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) { 2090 unsigned NumElts = 2091 cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements(); 2092 int M = SV->getMaskValue(Lane); 2093 if (M < 0) 2094 return {nullptr, PoisonMaskElem}; 2095 if (static_cast<unsigned>(M) < NumElts) { 2096 U = &SV->getOperandUse(0); 2097 Lane = M; 2098 } else { 2099 U = &SV->getOperandUse(1); 2100 Lane = M - NumElts; 2101 } 2102 } 2103 return InstLane{U, Lane}; 2104 } 2105 2106 static SmallVector<InstLane> 2107 generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) { 2108 SmallVector<InstLane> NItem; 2109 for (InstLane IL : Item) { 2110 auto [U, Lane] = IL; 2111 InstLane OpLane = 2112 U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op), 2113 Lane) 2114 : InstLane{nullptr, PoisonMaskElem}; 2115 NItem.emplace_back(OpLane); 2116 } 2117 return NItem; 2118 } 2119 2120 /// Detect concat of multiple values into a vector 2121 static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind, 2122 const TargetTransformInfo &TTI) { 2123 auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType()); 2124 unsigned NumElts = Ty->getNumElements(); 2125 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0) 2126 return false; 2127 2128 // Check that the concat is free, usually meaning that the type will be split 2129 // during legalization. 2130 SmallVector<int, 16> ConcatMask(NumElts * 2); 2131 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 2132 if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, CostKind) != 0) 2133 return false; 2134 2135 unsigned NumSlices = Item.size() / NumElts; 2136 // Currently we generate a tree of shuffles for the concats, which limits us 2137 // to a power2. 
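  // e.g. 4 slices are concatenated as (s0 ++ s1) ++ (s2 ++ s3) with two levels
  // of two-source shuffles, so a non-power-of-2 slice count is rejected here.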
2138 if (!isPowerOf2_32(NumSlices)) 2139 return false; 2140 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) { 2141 Use *SliceV = Item[Slice * NumElts].first; 2142 if (!SliceV || SliceV->get()->getType() != Ty) 2143 return false; 2144 for (unsigned Elt = 0; Elt < NumElts; ++Elt) { 2145 auto [V, Lane] = Item[Slice * NumElts + Elt]; 2146 if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get()) 2147 return false; 2148 } 2149 } 2150 return true; 2151 } 2152 2153 static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, 2154 const SmallPtrSet<Use *, 4> &IdentityLeafs, 2155 const SmallPtrSet<Use *, 4> &SplatLeafs, 2156 const SmallPtrSet<Use *, 4> &ConcatLeafs, 2157 IRBuilder<> &Builder, 2158 const TargetTransformInfo *TTI) { 2159 auto [FrontU, FrontLane] = Item.front(); 2160 2161 if (IdentityLeafs.contains(FrontU)) { 2162 return FrontU->get(); 2163 } 2164 if (SplatLeafs.contains(FrontU)) { 2165 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane); 2166 return Builder.CreateShuffleVector(FrontU->get(), Mask); 2167 } 2168 if (ConcatLeafs.contains(FrontU)) { 2169 unsigned NumElts = 2170 cast<FixedVectorType>(FrontU->get()->getType())->getNumElements(); 2171 SmallVector<Value *> Values(Item.size() / NumElts, nullptr); 2172 for (unsigned S = 0; S < Values.size(); ++S) 2173 Values[S] = Item[S * NumElts].first->get(); 2174 2175 while (Values.size() > 1) { 2176 NumElts *= 2; 2177 SmallVector<int, 16> Mask(NumElts, 0); 2178 std::iota(Mask.begin(), Mask.end(), 0); 2179 SmallVector<Value *> NewValues(Values.size() / 2, nullptr); 2180 for (unsigned S = 0; S < NewValues.size(); ++S) 2181 NewValues[S] = 2182 Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask); 2183 Values = NewValues; 2184 } 2185 return Values[0]; 2186 } 2187 2188 auto *I = cast<Instruction>(FrontU->get()); 2189 auto *II = dyn_cast<IntrinsicInst>(I); 2190 unsigned NumOps = I->getNumOperands() - (II ? 
1 : 0); 2191 SmallVector<Value *> Ops(NumOps); 2192 for (unsigned Idx = 0; Idx < NumOps; Idx++) { 2193 if (II && 2194 isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) { 2195 Ops[Idx] = II->getOperand(Idx); 2196 continue; 2197 } 2198 Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), 2199 Ty, IdentityLeafs, SplatLeafs, ConcatLeafs, 2200 Builder, TTI); 2201 } 2202 2203 SmallVector<Value *, 8> ValueList; 2204 for (const auto &Lane : Item) 2205 if (Lane.first) 2206 ValueList.push_back(Lane.first->get()); 2207 2208 Type *DstTy = 2209 FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements()); 2210 if (auto *BI = dyn_cast<BinaryOperator>(I)) { 2211 auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), 2212 Ops[0], Ops[1]); 2213 propagateIRFlags(Value, ValueList); 2214 return Value; 2215 } 2216 if (auto *CI = dyn_cast<CmpInst>(I)) { 2217 auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]); 2218 propagateIRFlags(Value, ValueList); 2219 return Value; 2220 } 2221 if (auto *SI = dyn_cast<SelectInst>(I)) { 2222 auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI); 2223 propagateIRFlags(Value, ValueList); 2224 return Value; 2225 } 2226 if (auto *CI = dyn_cast<CastInst>(I)) { 2227 auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(), 2228 Ops[0], DstTy); 2229 propagateIRFlags(Value, ValueList); 2230 return Value; 2231 } 2232 if (II) { 2233 auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops); 2234 propagateIRFlags(Value, ValueList); 2235 return Value; 2236 } 2237 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate"); 2238 auto *Value = 2239 Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]); 2240 propagateIRFlags(Value, ValueList); 2241 return Value; 2242 } 2243 2244 // Starting from a shuffle, look up through operands tracking the shuffled index 2245 // of each lane. If we can simplify away the shuffles to identities then 2246 // do so. 2247 bool VectorCombine::foldShuffleToIdentity(Instruction &I) { 2248 auto *Ty = dyn_cast<FixedVectorType>(I.getType()); 2249 if (!Ty || I.use_empty()) 2250 return false; 2251 2252 SmallVector<InstLane> Start(Ty->getNumElements()); 2253 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M) 2254 Start[M] = lookThroughShuffles(&*I.use_begin(), M); 2255 2256 SmallVector<SmallVector<InstLane>> Worklist; 2257 Worklist.push_back(Start); 2258 SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs; 2259 unsigned NumVisited = 0; 2260 2261 while (!Worklist.empty()) { 2262 if (++NumVisited > MaxInstrsToScan) 2263 return false; 2264 2265 SmallVector<InstLane> Item = Worklist.pop_back_val(); 2266 auto [FrontU, FrontLane] = Item.front(); 2267 2268 // If we found an undef first lane then bail out to keep things simple. 2269 if (!FrontU) 2270 return false; 2271 2272 // Helper to peek through bitcasts to the same value. 2273 auto IsEquiv = [&](Value *X, Value *Y) { 2274 return X->getType() == Y->getType() && 2275 peekThroughBitcasts(X) == peekThroughBitcasts(Y); 2276 }; 2277 2278 // Look for an identity value. 
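    // i.e. every tracked lane L resolves to lane L of one source vector of the
    // same width (conceptually Item = {(V,0), (V,1), (V,2), (V,3)} for a
    // 4-element V), so V can be used directly with no shuffle at all.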
2279 if (FrontLane == 0 && 2280 cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() == 2281 Ty->getNumElements() && 2282 all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) { 2283 Value *FrontV = Item.front().first->get(); 2284 return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) && 2285 E.value().second == (int)E.index()); 2286 })) { 2287 IdentityLeafs.insert(FrontU); 2288 continue; 2289 } 2290 // Look for constants, for the moment only supporting constant splats. 2291 if (auto *C = dyn_cast<Constant>(FrontU); 2292 C && C->getSplatValue() && 2293 all_of(drop_begin(Item), [Item](InstLane &IL) { 2294 Value *FrontV = Item.front().first->get(); 2295 Use *U = IL.first; 2296 return !U || (isa<Constant>(U->get()) && 2297 cast<Constant>(U->get())->getSplatValue() == 2298 cast<Constant>(FrontV)->getSplatValue()); 2299 })) { 2300 SplatLeafs.insert(FrontU); 2301 continue; 2302 } 2303 // Look for a splat value. 2304 if (all_of(drop_begin(Item), [Item](InstLane &IL) { 2305 auto [FrontU, FrontLane] = Item.front(); 2306 auto [U, Lane] = IL; 2307 return !U || (U->get() == FrontU->get() && Lane == FrontLane); 2308 })) { 2309 SplatLeafs.insert(FrontU); 2310 continue; 2311 } 2312 2313 // We need each element to be the same type of value, and check that each 2314 // element has a single use. 2315 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) { 2316 Value *FrontV = Item.front().first->get(); 2317 if (!IL.first) 2318 return true; 2319 Value *V = IL.first->get(); 2320 if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse()) 2321 return false; 2322 if (V->getValueID() != FrontV->getValueID()) 2323 return false; 2324 if (auto *CI = dyn_cast<CmpInst>(V)) 2325 if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate()) 2326 return false; 2327 if (auto *CI = dyn_cast<CastInst>(V)) 2328 if (CI->getSrcTy()->getScalarType() != 2329 cast<CastInst>(FrontV)->getSrcTy()->getScalarType()) 2330 return false; 2331 if (auto *SI = dyn_cast<SelectInst>(V)) 2332 if (!isa<VectorType>(SI->getOperand(0)->getType()) || 2333 SI->getOperand(0)->getType() != 2334 cast<SelectInst>(FrontV)->getOperand(0)->getType()) 2335 return false; 2336 if (isa<CallInst>(V) && !isa<IntrinsicInst>(V)) 2337 return false; 2338 auto *II = dyn_cast<IntrinsicInst>(V); 2339 return !II || (isa<IntrinsicInst>(FrontV) && 2340 II->getIntrinsicID() == 2341 cast<IntrinsicInst>(FrontV)->getIntrinsicID() && 2342 !II->hasOperandBundles()); 2343 }; 2344 if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) { 2345 // Check the operator is one that we support. 2346 if (isa<BinaryOperator, CmpInst>(FrontU)) { 2347 // We exclude div/rem in case they hit UB from poison lanes. 2348 if (auto *BO = dyn_cast<BinaryOperator>(FrontU); 2349 BO && BO->isIntDivRem()) 2350 return false; 2351 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2352 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); 2353 continue; 2354 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst, 2355 FPToUIInst, SIToFPInst, UIToFPInst>(FrontU)) { 2356 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2357 continue; 2358 } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) { 2359 // TODO: Handle vector widening/narrowing bitcasts. 
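      // e.g. a <4 x i32> -> <8 x i16> bitcast changes the lane count, so the
      // per-lane tracking used here would no longer be one-to-one; only
      // element-count-preserving bitcasts are looked through for now.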
2360 auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy()); 2361 auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy()); 2362 if (DstTy && SrcTy && 2363 SrcTy->getNumElements() == DstTy->getNumElements()) { 2364 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2365 continue; 2366 } 2367 } else if (isa<SelectInst>(FrontU)) { 2368 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2369 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); 2370 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2)); 2371 continue; 2372 } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU); 2373 II && isTriviallyVectorizable(II->getIntrinsicID()) && 2374 !II->hasOperandBundles()) { 2375 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) { 2376 if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op, 2377 &TTI)) { 2378 if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) { 2379 Value *FrontV = Item.front().first->get(); 2380 Use *U = IL.first; 2381 return !U || (cast<Instruction>(U->get())->getOperand(Op) == 2382 cast<Instruction>(FrontV)->getOperand(Op)); 2383 })) 2384 return false; 2385 continue; 2386 } 2387 Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op)); 2388 } 2389 continue; 2390 } 2391 } 2392 2393 if (isFreeConcat(Item, CostKind, TTI)) { 2394 ConcatLeafs.insert(FrontU); 2395 continue; 2396 } 2397 2398 return false; 2399 } 2400 2401 if (NumVisited <= 1) 2402 return false; 2403 2404 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n"); 2405 2406 // If we got this far, we know the shuffles are superfluous and can be 2407 // removed. Scan through again and generate the new tree of instructions. 2408 Builder.SetInsertPoint(&I); 2409 Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, 2410 ConcatLeafs, Builder, &TTI); 2411 replaceValue(I, *V); 2412 return true; 2413 } 2414 2415 /// Given a commutative reduction, the order of the input lanes does not alter 2416 /// the results. We can use this to remove certain shuffles feeding the 2417 /// reduction, removing the need to shuffle at all. 2418 bool VectorCombine::foldShuffleFromReductions(Instruction &I) { 2419 auto *II = dyn_cast<IntrinsicInst>(&I); 2420 if (!II) 2421 return false; 2422 switch (II->getIntrinsicID()) { 2423 case Intrinsic::vector_reduce_add: 2424 case Intrinsic::vector_reduce_mul: 2425 case Intrinsic::vector_reduce_and: 2426 case Intrinsic::vector_reduce_or: 2427 case Intrinsic::vector_reduce_xor: 2428 case Intrinsic::vector_reduce_smin: 2429 case Intrinsic::vector_reduce_smax: 2430 case Intrinsic::vector_reduce_umin: 2431 case Intrinsic::vector_reduce_umax: 2432 break; 2433 default: 2434 return false; 2435 } 2436 2437 // Find all the inputs when looking through operations that do not alter the 2438 // lane order (binops, for example). Currently we look for a single shuffle, 2439 // and can ignore splat values. 2440 std::queue<Value *> Worklist; 2441 SmallPtrSet<Value *, 4> Visited; 2442 ShuffleVectorInst *Shuffle = nullptr; 2443 if (auto *Op = dyn_cast<Instruction>(I.getOperand(0))) 2444 Worklist.push(Op); 2445 2446 while (!Worklist.empty()) { 2447 Value *CV = Worklist.front(); 2448 Worklist.pop(); 2449 if (Visited.contains(CV)) 2450 continue; 2451 2452 // Splats don't change the order, so can be safely ignored. 
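    // A splat holds the same value in every lane, so no matter how the other
    // operand's lanes are reordered, each lane still pairs with that value.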
2453 if (isSplatValue(CV)) 2454 continue; 2455 2456 Visited.insert(CV); 2457 2458 if (auto *CI = dyn_cast<Instruction>(CV)) { 2459 if (CI->isBinaryOp()) { 2460 for (auto *Op : CI->operand_values()) 2461 Worklist.push(Op); 2462 continue; 2463 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) { 2464 if (Shuffle && Shuffle != SV) 2465 return false; 2466 Shuffle = SV; 2467 continue; 2468 } 2469 } 2470 2471 // Anything else is currently an unknown node. 2472 return false; 2473 } 2474 2475 if (!Shuffle) 2476 return false; 2477 2478 // Check all uses of the binary ops and shuffles are also included in the 2479 // lane-invariant operations (Visited should be the list of lanewise 2480 // instructions, including the shuffle that we found). 2481 for (auto *V : Visited) 2482 for (auto *U : V->users()) 2483 if (!Visited.contains(U) && U != &I) 2484 return false; 2485 2486 FixedVectorType *VecType = 2487 dyn_cast<FixedVectorType>(II->getOperand(0)->getType()); 2488 if (!VecType) 2489 return false; 2490 FixedVectorType *ShuffleInputType = 2491 dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType()); 2492 if (!ShuffleInputType) 2493 return false; 2494 unsigned NumInputElts = ShuffleInputType->getNumElements(); 2495 2496 // Find the mask from sorting the lanes into order. This is most likely to 2497 // become a identity or concat mask. Undef elements are pushed to the end. 2498 SmallVector<int> ConcatMask; 2499 Shuffle->getShuffleMask(ConcatMask); 2500 sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; }); 2501 // In the case of a truncating shuffle it's possible for the mask 2502 // to have an index greater than the size of the resulting vector. 2503 // This requires special handling. 2504 bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts; 2505 bool UsesSecondVec = 2506 any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; }); 2507 2508 FixedVectorType *VecTyForCost = 2509 (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType; 2510 InstructionCost OldCost = TTI.getShuffleCost( 2511 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, 2512 VecTyForCost, Shuffle->getShuffleMask(), CostKind); 2513 InstructionCost NewCost = TTI.getShuffleCost( 2514 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, 2515 VecTyForCost, ConcatMask, CostKind); 2516 2517 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle 2518 << "\n"); 2519 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost 2520 << "\n"); 2521 if (NewCost < OldCost) { 2522 Builder.SetInsertPoint(Shuffle); 2523 Value *NewShuffle = Builder.CreateShuffleVector( 2524 Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask); 2525 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n"); 2526 replaceValue(*Shuffle, *NewShuffle); 2527 } 2528 2529 // See if we can re-use foldSelectShuffle, getting it to reduce the size of 2530 // the shuffle into a nicer order, as it can ignore the order of the shuffles. 2531 return foldSelectShuffle(*Shuffle, true); 2532 } 2533 2534 /// Determine if its more efficient to fold: 2535 /// reduce(trunc(x)) -> trunc(reduce(x)). 2536 /// reduce(sext(x)) -> sext(reduce(x)). 2537 /// reduce(zext(x)) -> zext(reduce(x)). 
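/// A hypothetical illustration of the zext case (an and-reduction, where the
/// extension can legally be hoisted past the reduction):
///   %w = zext <8 x i8> %x to <8 x i32>
///   %r = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> %w)
/// is expected to become
///   %n = call i8 @llvm.vector.reduce.and.v8i8(<8 x i8> %x)
///   %r = zext i8 %n to i32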
2538 bool VectorCombine::foldCastFromReductions(Instruction &I) { 2539 auto *II = dyn_cast<IntrinsicInst>(&I); 2540 if (!II) 2541 return false; 2542 2543 bool TruncOnly = false; 2544 Intrinsic::ID IID = II->getIntrinsicID(); 2545 switch (IID) { 2546 case Intrinsic::vector_reduce_add: 2547 case Intrinsic::vector_reduce_mul: 2548 TruncOnly = true; 2549 break; 2550 case Intrinsic::vector_reduce_and: 2551 case Intrinsic::vector_reduce_or: 2552 case Intrinsic::vector_reduce_xor: 2553 break; 2554 default: 2555 return false; 2556 } 2557 2558 unsigned ReductionOpc = getArithmeticReductionInstruction(IID); 2559 Value *ReductionSrc = I.getOperand(0); 2560 2561 Value *Src; 2562 if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) && 2563 (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src)))))) 2564 return false; 2565 2566 auto CastOpc = 2567 (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode(); 2568 2569 auto *SrcTy = cast<VectorType>(Src->getType()); 2570 auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType()); 2571 Type *ResultTy = I.getType(); 2572 2573 InstructionCost OldCost = TTI.getArithmeticReductionCost( 2574 ReductionOpc, ReductionSrcTy, std::nullopt, CostKind); 2575 OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy, 2576 TTI::CastContextHint::None, CostKind, 2577 cast<CastInst>(ReductionSrc)); 2578 InstructionCost NewCost = 2579 TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt, 2580 CostKind) + 2581 TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(), 2582 TTI::CastContextHint::None, CostKind); 2583 2584 if (OldCost <= NewCost || !NewCost.isValid()) 2585 return false; 2586 2587 Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(), 2588 II->getIntrinsicID(), {Src}); 2589 Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy); 2590 replaceValue(I, *NewCast); 2591 return true; 2592 } 2593 2594 /// This method looks for groups of shuffles acting on binops, of the form: 2595 /// %x = shuffle ... 2596 /// %y = shuffle ... 2597 /// %a = binop %x, %y 2598 /// %b = binop %x, %y 2599 /// shuffle %a, %b, selectmask 2600 /// We may, especially if the shuffle is wider than legal, be able to convert 2601 /// the shuffle to a form where only parts of a and b need to be computed. On 2602 /// architectures with no obvious "select" shuffle, this can reduce the total 2603 /// number of operations if the target reports them as cheaper. 
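/// For instance, with a hypothetical selectmask of <0, 5, 2, 7>, only lanes
/// {0, 2} of %a and lanes {1, 3} of %b are ever observed, so both binops can
/// be recomputed on narrower packed vectors and a final shuffle can restore
/// the original lane order.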
2604 bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { 2605 auto *SVI = cast<ShuffleVectorInst>(&I); 2606 auto *VT = cast<FixedVectorType>(I.getType()); 2607 auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0)); 2608 auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1)); 2609 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() || 2610 VT != Op0->getType()) 2611 return false; 2612 2613 auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0)); 2614 auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1)); 2615 auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0)); 2616 auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1)); 2617 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B}); 2618 auto checkSVNonOpUses = [&](Instruction *I) { 2619 if (!I || I->getOperand(0)->getType() != VT) 2620 return true; 2621 return any_of(I->users(), [&](User *U) { 2622 return U != Op0 && U != Op1 && 2623 !(isa<ShuffleVectorInst>(U) && 2624 (InputShuffles.contains(cast<Instruction>(U)) || 2625 isInstructionTriviallyDead(cast<Instruction>(U)))); 2626 }); 2627 }; 2628 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) || 2629 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B)) 2630 return false; 2631 2632 // Collect all the uses that are shuffles that we can transform together. We 2633 // may not have a single shuffle, but a group that can all be transformed 2634 // together profitably. 2635 SmallVector<ShuffleVectorInst *> Shuffles; 2636 auto collectShuffles = [&](Instruction *I) { 2637 for (auto *U : I->users()) { 2638 auto *SV = dyn_cast<ShuffleVectorInst>(U); 2639 if (!SV || SV->getType() != VT) 2640 return false; 2641 if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) || 2642 (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1)) 2643 return false; 2644 if (!llvm::is_contained(Shuffles, SV)) 2645 Shuffles.push_back(SV); 2646 } 2647 return true; 2648 }; 2649 if (!collectShuffles(Op0) || !collectShuffles(Op1)) 2650 return false; 2651 // From a reduction, we need to be processing a single shuffle, otherwise the 2652 // other uses will not be lane-invariant. 2653 if (FromReduction && Shuffles.size() > 1) 2654 return false; 2655 2656 // Add any shuffle uses for the shuffles we have found, to include them in our 2657 // cost calculations. 2658 if (!FromReduction) { 2659 for (ShuffleVectorInst *SV : Shuffles) { 2660 for (auto *U : SV->users()) { 2661 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U); 2662 if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT) 2663 Shuffles.push_back(SSV); 2664 } 2665 } 2666 } 2667 2668 // For each of the output shuffles, we try to sort all the first vector 2669 // elements to the beginning, followed by the second array elements at the 2670 // end. If the binops are legalized to smaller vectors, this may reduce total 2671 // number of binops. We compute the ReconstructMask mask needed to convert 2672 // back to the original lane order. 2673 SmallVector<std::pair<int, int>> V1, V2; 2674 SmallVector<SmallVector<int>> OrigReconstructMasks; 2675 int MaxV1Elt = 0, MaxV2Elt = 0; 2676 unsigned NumElts = VT->getNumElements(); 2677 for (ShuffleVectorInst *SVN : Shuffles) { 2678 SmallVector<int> Mask; 2679 SVN->getShuffleMask(Mask); 2680 2681 // Check the operands are the same as the original, or reversed (in which 2682 // case we need to commute the mask). 
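    // Commuting swaps which input each mask element refers to, e.g. a
    // hypothetical "shuffle %b, %a, <0, 4, 1, 5>" reads the same lanes as
    // "shuffle %a, %b, <4, 0, 5, 1>".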
    Value *SVOp0 = SVN->getOperand(0);
    Value *SVOp1 = SVN->getOperand(1);
    if (isa<UndefValue>(SVOp1)) {
      auto *SSV = cast<ShuffleVectorInst>(SVOp0);
      SVOp0 = SSV->getOperand(0);
      SVOp1 = SSV->getOperand(1);
      for (unsigned I = 0, E = Mask.size(); I != E; I++) {
        if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
          return false;
        Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
      }
    }
    if (SVOp0 == Op1 && SVOp1 == Op0) {
      std::swap(SVOp0, SVOp1);
      ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
    }
    if (SVOp0 != Op0 || SVOp1 != Op1)
      return false;

    // Calculate the reconstruction mask for this shuffle, as the mask needed
    // to take the packed values from Op0/Op1 and reconstruct the original
    // order.
    SmallVector<int> ReconstructMask;
    for (unsigned I = 0; I < Mask.size(); I++) {
      if (Mask[I] < 0) {
        ReconstructMask.push_back(-1);
      } else if (Mask[I] < static_cast<int>(NumElts)) {
        MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
        auto It = find_if(V1, [&](const std::pair<int, int> &A) {
          return Mask[I] == A.first;
        });
        if (It != V1.end())
          ReconstructMask.push_back(It - V1.begin());
        else {
          ReconstructMask.push_back(V1.size());
          V1.emplace_back(Mask[I], V1.size());
        }
      } else {
        MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
        auto It = find_if(V2, [&](const std::pair<int, int> &A) {
          return Mask[I] - static_cast<int>(NumElts) == A.first;
        });
        if (It != V2.end())
          ReconstructMask.push_back(NumElts + It - V2.begin());
        else {
          ReconstructMask.push_back(NumElts + V2.size());
          V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
        }
      }
    }

    // For reductions, we know that the lane ordering out doesn't alter the
    // result. In-order can help simplify the shuffle away.
    if (FromReduction)
      sort(ReconstructMask);
    OrigReconstructMasks.push_back(std::move(ReconstructMask));
  }

  // If the maximum elements used from V1 and V2 are not larger than the new
  // vectors, the vectors are already packed and performing the optimization
  // again will likely not help any further. This also prevents us from getting
  // stuck in a cycle in case the costs do not also rule it out.
  if (V1.empty() || V2.empty() ||
      (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
       MaxV2Elt == static_cast<int>(V2.size()) - 1))
    return false;

  // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
  // shuffle of another shuffle, or not a shuffle (that is treated like an
  // identity shuffle).
  auto GetBaseMaskValue = [&](Instruction *I, int M) {
    auto *SV = dyn_cast<ShuffleVectorInst>(I);
    if (!SV)
      return M;
    if (isa<UndefValue>(SV->getOperand(1)))
      if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
        if (InputShuffles.contains(SSV))
          return SSV->getMaskValue(SV->getMaskValue(M));
    return SV->getMaskValue(M);
  };

  // Attempt to sort the inputs by ascending mask values to make simpler input
  // shuffles and push complex shuffles down to the uses. We sort on the first
  // of the two input shuffle orders, to try and get at least one input into a
  // nice order.
2768 auto SortBase = [&](Instruction *A, std::pair<int, int> X, 2769 std::pair<int, int> Y) { 2770 int MXA = GetBaseMaskValue(A, X.first); 2771 int MYA = GetBaseMaskValue(A, Y.first); 2772 return MXA < MYA; 2773 }; 2774 stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) { 2775 return SortBase(SVI0A, A, B); 2776 }); 2777 stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) { 2778 return SortBase(SVI1A, A, B); 2779 }); 2780 // Calculate our ReconstructMasks from the OrigReconstructMasks and the 2781 // modified order of the input shuffles. 2782 SmallVector<SmallVector<int>> ReconstructMasks; 2783 for (const auto &Mask : OrigReconstructMasks) { 2784 SmallVector<int> ReconstructMask; 2785 for (int M : Mask) { 2786 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) { 2787 auto It = find_if(V, [M](auto A) { return A.second == M; }); 2788 assert(It != V.end() && "Expected all entries in Mask"); 2789 return std::distance(V.begin(), It); 2790 }; 2791 if (M < 0) 2792 ReconstructMask.push_back(-1); 2793 else if (M < static_cast<int>(NumElts)) { 2794 ReconstructMask.push_back(FindIndex(V1, M)); 2795 } else { 2796 ReconstructMask.push_back(NumElts + FindIndex(V2, M)); 2797 } 2798 } 2799 ReconstructMasks.push_back(std::move(ReconstructMask)); 2800 } 2801 2802 // Calculate the masks needed for the new input shuffles, which get padded 2803 // with undef 2804 SmallVector<int> V1A, V1B, V2A, V2B; 2805 for (unsigned I = 0; I < V1.size(); I++) { 2806 V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first)); 2807 V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first)); 2808 } 2809 for (unsigned I = 0; I < V2.size(); I++) { 2810 V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first)); 2811 V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first)); 2812 } 2813 while (V1A.size() < NumElts) { 2814 V1A.push_back(PoisonMaskElem); 2815 V1B.push_back(PoisonMaskElem); 2816 } 2817 while (V2A.size() < NumElts) { 2818 V2A.push_back(PoisonMaskElem); 2819 V2B.push_back(PoisonMaskElem); 2820 } 2821 2822 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) { 2823 auto *SV = dyn_cast<ShuffleVectorInst>(I); 2824 if (!SV) 2825 return C; 2826 return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1)) 2827 ? TTI::SK_PermuteSingleSrc 2828 : TTI::SK_PermuteTwoSrc, 2829 VT, SV->getShuffleMask(), CostKind); 2830 }; 2831 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) { 2832 return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask, CostKind); 2833 }; 2834 2835 // Get the costs of the shuffles + binops before and after with the new 2836 // shuffle masks. 2837 InstructionCost CostBefore = 2838 TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) + 2839 TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind); 2840 CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(), 2841 InstructionCost(0), AddShuffleCost); 2842 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(), 2843 InstructionCost(0), AddShuffleCost); 2844 2845 // The new binops will be unused for lanes past the used shuffle lengths. 2846 // These types attempt to get the correct cost for that from the target. 
2847 FixedVectorType *Op0SmallVT = 2848 FixedVectorType::get(VT->getScalarType(), V1.size()); 2849 FixedVectorType *Op1SmallVT = 2850 FixedVectorType::get(VT->getScalarType(), V2.size()); 2851 InstructionCost CostAfter = 2852 TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) + 2853 TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind); 2854 CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(), 2855 InstructionCost(0), AddShuffleMaskCost); 2856 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B}); 2857 CostAfter += 2858 std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(), 2859 InstructionCost(0), AddShuffleMaskCost); 2860 2861 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n"); 2862 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore 2863 << " vs CostAfter: " << CostAfter << "\n"); 2864 if (CostBefore <= CostAfter) 2865 return false; 2866 2867 // The cost model has passed, create the new instructions. 2868 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * { 2869 auto *SV = dyn_cast<ShuffleVectorInst>(I); 2870 if (!SV) 2871 return I; 2872 if (isa<UndefValue>(SV->getOperand(1))) 2873 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0))) 2874 if (InputShuffles.contains(SSV)) 2875 return SSV->getOperand(Op); 2876 return SV->getOperand(Op); 2877 }; 2878 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef()); 2879 Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0), 2880 GetShuffleOperand(SVI0A, 1), V1A); 2881 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef()); 2882 Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0), 2883 GetShuffleOperand(SVI0B, 1), V1B); 2884 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef()); 2885 Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0), 2886 GetShuffleOperand(SVI1A, 1), V2A); 2887 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef()); 2888 Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0), 2889 GetShuffleOperand(SVI1B, 1), V2B); 2890 Builder.SetInsertPoint(Op0); 2891 Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(), 2892 NSV0A, NSV0B); 2893 if (auto *I = dyn_cast<Instruction>(NOp0)) 2894 I->copyIRFlags(Op0, true); 2895 Builder.SetInsertPoint(Op1); 2896 Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(), 2897 NSV1A, NSV1B); 2898 if (auto *I = dyn_cast<Instruction>(NOp1)) 2899 I->copyIRFlags(Op1, true); 2900 2901 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) { 2902 Builder.SetInsertPoint(Shuffles[S]); 2903 Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]); 2904 replaceValue(*Shuffles[S], *NSV); 2905 } 2906 2907 Worklist.pushValue(NSV0A); 2908 Worklist.pushValue(NSV0B); 2909 Worklist.pushValue(NSV1A); 2910 Worklist.pushValue(NSV1B); 2911 for (auto *S : Shuffles) 2912 Worklist.add(S); 2913 return true; 2914 } 2915 2916 /// Check if instruction depends on ZExt and this ZExt can be moved after the 2917 /// instruction. Move ZExt if it is profitable. For example: 2918 /// logic(zext(x),y) -> zext(logic(x,trunc(y))) 2919 /// lshr((zext(x),y) -> zext(lshr(x,trunc(y))) 2920 /// Cost model calculations takes into account if zext(x) has other users and 2921 /// whether it can be propagated through them too. 
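/// A hypothetical illustration of the logic case:
///   %zx = zext <4 x i16> %x to <4 x i32>
///   %r  = and <4 x i32> %zx, <i32 255, i32 255, i32 255, i32 255>
/// never uses more active bits than the narrow type provides, so it can be
/// rewritten as
///   %a  = and <4 x i16> %x, <i16 255, i16 255, i16 255, i16 255>
///   %r  = zext <4 x i16> %a to <4 x i32>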
bool VectorCombine::shrinkType(Instruction &I) {
  Value *ZExted, *OtherOperand;
  if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
                                  m_Value(OtherOperand))) &&
      !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
    return false;

  Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);

  auto *BigTy = cast<FixedVectorType>(I.getType());
  auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();

  if (I.getOpcode() == Instruction::LShr) {
    // Check that the shift amount is less than the number of bits in the
    // smaller type. Otherwise, the smaller lshr will return a poison value.
    KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
    if (ShAmtKB.getMaxValue().uge(BW))
      return false;
  } else {
    // Check that the expression overall uses at most the same number of bits
    // as ZExted.
    KnownBits KB = computeKnownBits(&I, *DL);
    if (KB.countMaxActiveBits() > BW)
      return false;
  }

  // Calculate the costs of leaving the current IR as it is versus moving the
  // ZExt operation later, along with adding truncates if needed.
  InstructionCost ZExtCost = TTI.getCastInstrCost(
      Instruction::ZExt, BigTy, SmallTy,
      TargetTransformInfo::CastContextHint::None, CostKind);
  InstructionCost CurrentCost = ZExtCost;
  InstructionCost ShrinkCost = 0;

  // Calculate the total cost and check that we can propagate through all
  // users of the ZExt.
  for (User *U : ZExtOperand->users()) {
    auto *UI = cast<Instruction>(U);
    if (UI == &I) {
      CurrentCost +=
          TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
      ShrinkCost +=
          TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
      ShrinkCost += ZExtCost;
      continue;
    }

    if (!Instruction::isBinaryOp(UI->getOpcode()))
      return false;

    // Check if we can propagate the ZExt through its other users.
    KnownBits KB = computeKnownBits(UI, *DL);
    if (KB.countMaxActiveBits() > BW)
      return false;

    CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
    ShrinkCost +=
        TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
    ShrinkCost += ZExtCost;
  }

  // If the other operand is not a constant, we'll need to generate a truncate
  // instruction, so we have to adjust the cost.
  if (!isa<Constant>(OtherOperand))
    ShrinkCost += TTI.getCastInstrCost(
        Instruction::Trunc, SmallTy, BigTy,
        TargetTransformInfo::CastContextHint::None, CostKind);

  // If the cost of shrinking types and leaving the IR is the same, we'll lean
  // towards modifying the IR because shrinking opens opportunities for other
  // shrinking optimisations.
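  // Hence the strict comparison below: equal costs still permit the rewrite.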
  if (ShrinkCost > CurrentCost)
    return false;

  Builder.SetInsertPoint(&I);
  Value *Op0 = ZExted;
  Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
  // Keep the order of operands the same.
  if (I.getOperand(0) == OtherOperand)
    std::swap(Op0, Op1);
  Value *NewBinOp =
      Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
  cast<Instruction>(NewBinOp)->copyIRFlags(&I);
  cast<Instruction>(NewBinOp)->copyMetadata(I);
  Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
  replaceValue(I, *NewZExtr);
  return true;
}

/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
///   shuffle (DstVec, SrcVec, Mask)
bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
  Value *DstVec, *SrcVec;
  uint64_t ExtIdx, InsIdx;
  if (!match(&I,
             m_InsertElt(m_Value(DstVec),
                         m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
                         m_ConstantInt(InsIdx))))
    return false;

  auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
  if (!VecTy || SrcVec->getType() != VecTy)
    return false;

  unsigned NumElts = VecTy->getNumElements();
  if (ExtIdx >= NumElts || InsIdx >= NumElts)
    return false;

  // Insertion into poison is a cheaper single-operand shuffle.
  TargetTransformInfo::ShuffleKind SK;
  SmallVector<int> Mask(NumElts, PoisonMaskElem);
  if (isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
    SK = TargetTransformInfo::SK_PermuteSingleSrc;
    Mask[InsIdx] = ExtIdx;
    std::swap(DstVec, SrcVec);
  } else {
    SK = TargetTransformInfo::SK_PermuteTwoSrc;
    std::iota(Mask.begin(), Mask.end(), 0);
    Mask[InsIdx] = ExtIdx + NumElts;
  }

  // Compare the cost of the original insert/extract pair with the replacement
  // shuffle.
  auto *Ins = cast<InsertElementInst>(&I);
  auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
  InstructionCost InsCost =
      TTI.getVectorInstrCost(*Ins, VecTy, CostKind, InsIdx);
  InstructionCost ExtCost =
      TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx);
  InstructionCost OldCost = ExtCost + InsCost;

  InstructionCost NewCost = TTI.getShuffleCost(SK, VecTy, Mask, CostKind, 0,
                                               nullptr, {DstVec, SrcVec});
  if (!Ext->hasOneUse())
    NewCost += ExtCost;

  LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (OldCost < NewCost)
    return false;

  // Canonicalize an undef operand to the RHS to help further folds.
  if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
    ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
    std::swap(DstVec, SrcVec);
  }

  Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
  replaceValue(I, *Shuf);

  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
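  // A target without usable vector registers is expected to report zero
  // registers for its vector register class, which is what the query below
  // checks.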
  if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");

  bool MadeChange = false;
  auto FoldInst = [this, &MadeChange](Instruction &I) {
    Builder.SetInsertPoint(&I);
    bool IsVectorType = isa<VectorType>(I.getType());
    bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
    auto Opcode = I.getOpcode();

    LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');

    // These folds should be beneficial regardless of when this pass is run
    // in the optimization pipeline.
    // The type checking is for run-time efficiency. We can avoid wasting time
    // dispatching to folding functions if there's no chance of matching.
    if (IsFixedVectorType) {
      switch (Opcode) {
      case Instruction::InsertElement:
        MadeChange |= vectorizeLoadInsert(I);
        break;
      case Instruction::ShuffleVector:
        MadeChange |= widenSubvectorLoad(I);
        break;
      default:
        break;
      }
    }

    // These transforms work with both scalable and fixed vectors.
    // TODO: Identify and allow other scalable transforms.
    if (IsVectorType) {
      MadeChange |= scalarizeBinopOrCmp(I);
      MadeChange |= scalarizeLoadExtract(I);
      MadeChange |= scalarizeVPIntrinsic(I);
    }

    if (Opcode == Instruction::Store)
      MadeChange |= foldSingleElementStore(I);

    // If this is an early pipeline invocation of this pass, we are done.
    if (TryEarlyFoldsOnly)
      return;

    // Otherwise, try folds that improve codegen but may interfere with
    // early IR canonicalizations.
    // The type checking is for run-time efficiency. We can avoid wasting time
    // dispatching to folding functions if there's no chance of matching.
    if (IsFixedVectorType) {
      switch (Opcode) {
      case Instruction::InsertElement:
        MadeChange |= foldInsExtFNeg(I);
        MadeChange |= foldInsExtVectorToShuffle(I);
        break;
      case Instruction::ShuffleVector:
        MadeChange |= foldPermuteOfBinops(I);
        MadeChange |= foldShuffleOfBinops(I);
        MadeChange |= foldShuffleOfCastops(I);
        MadeChange |= foldShuffleOfShuffles(I);
        MadeChange |= foldShuffleOfIntrinsics(I);
        MadeChange |= foldSelectShuffle(I);
        MadeChange |= foldShuffleToIdentity(I);
        break;
      case Instruction::BitCast:
        MadeChange |= foldBitcastShuffle(I);
        break;
      default:
        MadeChange |= shrinkType(I);
        break;
      }
    } else {
      switch (Opcode) {
      case Instruction::Call:
        MadeChange |= foldShuffleFromReductions(I);
        MadeChange |= foldCastFromReductions(I);
        break;
      case Instruction::ICmp:
      case Instruction::FCmp:
        MadeChange |= foldExtractExtract(I);
        break;
      case Instruction::Or:
        MadeChange |= foldConcatOfBoolMasks(I);
        [[fallthrough]];
      default:
        if (Instruction::isBinaryOp(Opcode)) {
          MadeChange |= foldExtractExtract(I);
          MadeChange |= foldExtractedCmps(I);
        }
        break;
      }
    }
  };

  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Use an early-increment range so that we can erase instructions in the
    // loop.
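    // make_early_inc_range advances the underlying iterator when the current
    // instruction is produced, so a fold that erases or replaces I does not
    // invalidate the traversal.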
    for (Instruction &I : make_early_inc_range(BB)) {
      if (I.isDebugOrPseudoInst())
        continue;
      FoldInst(I);
    }
  }

  while (!Worklist.isEmpty()) {
    Instruction *I = Worklist.removeOne();
    if (!I)
      continue;

    if (isInstructionTriviallyDead(I)) {
      eraseInstruction(*I);
      continue;
    }

    FoldInst(*I);
  }

  return MadeChange;
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  AAResults &AA = FAM.getResult<AAManager>(F);
  const DataLayout *DL = &F.getDataLayout();
  VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
                         TryEarlyFoldsOnly);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}
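// Usage sketch (an assumption about a typical LLVM install, not something this
// file defines): the pass can be exercised in isolation with
//   opt -passes=vector-combine -S input.ll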