1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass optimizes scalar/vector interactions using target cost models. The 10 // transforms implemented here may not fit in traditional loop-based or SLP 11 // vectorization passes. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/Vectorize/VectorCombine.h" 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/ADT/Statistic.h" 20 #include "llvm/Analysis/AssumptionCache.h" 21 #include "llvm/Analysis/BasicAliasAnalysis.h" 22 #include "llvm/Analysis/GlobalsModRef.h" 23 #include "llvm/Analysis/Loads.h" 24 #include "llvm/Analysis/TargetTransformInfo.h" 25 #include "llvm/Analysis/ValueTracking.h" 26 #include "llvm/Analysis/VectorUtils.h" 27 #include "llvm/IR/Dominators.h" 28 #include "llvm/IR/Function.h" 29 #include "llvm/IR/IRBuilder.h" 30 #include "llvm/IR/PatternMatch.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Transforms/Utils/Local.h" 33 #include "llvm/Transforms/Utils/LoopUtils.h" 34 #include <numeric> 35 #include <queue> 36 37 #define DEBUG_TYPE "vector-combine" 38 #include "llvm/Transforms/Utils/InstructionWorklist.h" 39 40 using namespace llvm; 41 using namespace llvm::PatternMatch; 42 43 STATISTIC(NumVecLoad, "Number of vector loads formed"); 44 STATISTIC(NumVecCmp, "Number of vector compares formed"); 45 STATISTIC(NumVecBO, "Number of vector binops formed"); 46 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed"); 47 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast"); 48 STATISTIC(NumScalarBO, "Number of scalar binops formed"); 49 STATISTIC(NumScalarCmp, "Number of scalar compares formed"); 50 51 static cl::opt<bool> DisableVectorCombine( 52 "disable-vector-combine", cl::init(false), cl::Hidden, 53 cl::desc("Disable all vector combine transforms")); 54 55 static cl::opt<bool> DisableBinopExtractShuffle( 56 "disable-binop-extract-shuffle", cl::init(false), cl::Hidden, 57 cl::desc("Disable binop extract to shuffle transforms")); 58 59 static cl::opt<unsigned> MaxInstrsToScan( 60 "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden, 61 cl::desc("Max number of instructions to scan for vector combining.")); 62 63 static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max(); 64 65 namespace { 66 class VectorCombine { 67 public: 68 VectorCombine(Function &F, const TargetTransformInfo &TTI, 69 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC, 70 const DataLayout *DL, TTI::TargetCostKind CostKind, 71 bool TryEarlyFoldsOnly) 72 : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), DL(DL), 73 CostKind(CostKind), TryEarlyFoldsOnly(TryEarlyFoldsOnly) {} 74 75 bool run(); 76 77 private: 78 Function &F; 79 IRBuilder<> Builder; 80 const TargetTransformInfo &TTI; 81 const DominatorTree &DT; 82 AAResults &AA; 83 AssumptionCache &AC; 84 const DataLayout *DL; 85 TTI::TargetCostKind CostKind; 86 87 /// If true, only perform beneficial early IR transforms. Do not introduce new 88 /// vector operations. 
89 bool TryEarlyFoldsOnly; 90 91 InstructionWorklist Worklist; 92 93 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction" 94 // parameter. That should be updated to specific sub-classes because the 95 // run loop was changed to dispatch on opcode. 96 bool vectorizeLoadInsert(Instruction &I); 97 bool widenSubvectorLoad(Instruction &I); 98 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0, 99 ExtractElementInst *Ext1, 100 unsigned PreferredExtractIndex) const; 101 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1, 102 const Instruction &I, 103 ExtractElementInst *&ConvertToShuffle, 104 unsigned PreferredExtractIndex); 105 void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1, 106 Instruction &I); 107 void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1, 108 Instruction &I); 109 bool foldExtractExtract(Instruction &I); 110 bool foldInsExtFNeg(Instruction &I); 111 bool foldInsExtVectorToShuffle(Instruction &I); 112 bool foldBitcastShuffle(Instruction &I); 113 bool scalarizeBinopOrCmp(Instruction &I); 114 bool scalarizeVPIntrinsic(Instruction &I); 115 bool foldExtractedCmps(Instruction &I); 116 bool foldSingleElementStore(Instruction &I); 117 bool scalarizeLoadExtract(Instruction &I); 118 bool foldConcatOfBoolMasks(Instruction &I); 119 bool foldPermuteOfBinops(Instruction &I); 120 bool foldShuffleOfBinops(Instruction &I); 121 bool foldShuffleOfCastops(Instruction &I); 122 bool foldShuffleOfShuffles(Instruction &I); 123 bool foldShuffleOfIntrinsics(Instruction &I); 124 bool foldShuffleToIdentity(Instruction &I); 125 bool foldShuffleFromReductions(Instruction &I); 126 bool foldCastFromReductions(Instruction &I); 127 bool foldSelectShuffle(Instruction &I, bool FromReduction = false); 128 bool shrinkType(Instruction &I); 129 130 void replaceValue(Value &Old, Value &New) { 131 Old.replaceAllUsesWith(&New); 132 if (auto *NewI = dyn_cast<Instruction>(&New)) { 133 New.takeName(&Old); 134 Worklist.pushUsersToWorkList(*NewI); 135 Worklist.pushValue(NewI); 136 } 137 Worklist.pushValue(&Old); 138 } 139 140 void eraseInstruction(Instruction &I) { 141 for (Value *Op : I.operands()) 142 Worklist.pushValue(Op); 143 Worklist.remove(&I); 144 I.eraseFromParent(); 145 } 146 }; 147 } // namespace 148 149 /// Return the source operand of a potentially bitcasted value. If there is no 150 /// bitcast, return the input value itself. 151 static Value *peekThroughBitcasts(Value *V) { 152 while (auto *BitCast = dyn_cast<BitCastInst>(V)) 153 V = BitCast->getOperand(0); 154 return V; 155 } 156 157 static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) { 158 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan. 159 // The widened load may load data from dirty regions or create data races 160 // non-existent in the source. 161 if (!Load || !Load->isSimple() || !Load->hasOneUse() || 162 Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) || 163 mustSuppressSpeculation(*Load)) 164 return false; 165 166 // We are potentially transforming byte-sized (8-bit) memory accesses, so make 167 // sure we have all of our type-based constraints in place for this target. 
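  // As a rough illustration (assuming getMinVectorRegisterBitWidth() is 128):
  // a scalar i32 load passes the checks below (128 % 32 == 0 and 32 % 8 == 0),
  // while an i1 or x86_fp80 load does not.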
168 Type *ScalarTy = Load->getType()->getScalarType(); 169 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits(); 170 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth(); 171 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 || 172 ScalarSize % 8 != 0) 173 return false; 174 175 return true; 176 } 177 178 bool VectorCombine::vectorizeLoadInsert(Instruction &I) { 179 // Match insert into fixed vector of scalar value. 180 // TODO: Handle non-zero insert index. 181 Value *Scalar; 182 if (!match(&I, 183 m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt()))) 184 return false; 185 186 // Optionally match an extract from another vector. 187 Value *X; 188 bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt())); 189 if (!HasExtract) 190 X = Scalar; 191 192 auto *Load = dyn_cast<LoadInst>(X); 193 if (!canWidenLoad(Load, TTI)) 194 return false; 195 196 Type *ScalarTy = Scalar->getType(); 197 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits(); 198 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth(); 199 200 // Check safety of replacing the scalar load with a larger vector load. 201 // We use minimal alignment (maximum flexibility) because we only care about 202 // the dereferenceable region. When calculating cost and creating a new op, 203 // we may use a larger value based on alignment attributes. 204 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); 205 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); 206 207 unsigned MinVecNumElts = MinVectorSize / ScalarSize; 208 auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false); 209 unsigned OffsetEltIndex = 0; 210 Align Alignment = Load->getAlign(); 211 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC, 212 &DT)) { 213 // It is not safe to load directly from the pointer, but we can still peek 214 // through gep offsets and check if it is safe to load from a base address with 215 // updated alignment. If it is, we can shuffle the element(s) into place 216 // after loading. 217 unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType()); 218 APInt Offset(OffsetBitWidth, 0); 219 SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset); 220 221 // We want to shuffle the result down from a high element of a vector, so 222 // the offset must be positive. 223 if (Offset.isNegative()) 224 return false; 225 226 // The offset must be a multiple of the scalar element to shuffle cleanly 227 // in the element's size. 228 uint64_t ScalarSizeInBytes = ScalarSize / 8; 229 if (Offset.urem(ScalarSizeInBytes) != 0) 230 return false; 231 232 // If we load MinVecNumElts, will our target element still be loaded? 233 OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue(); 234 if (OffsetEltIndex >= MinVecNumElts) 235 return false; 236 237 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC, 238 &DT)) 239 return false; 240 241 // Update alignment with offset value. Note that the offset could be negated 242 // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but 243 // negation does not change the result of the alignment calculation. 244 Alignment = commonAlignment(Alignment, Offset.getZExtValue()); 245 } 246 247 // Original pattern: insertelt undef, load [free casts of] PtrOp, 0 248 // Use the greater of the alignment on the load or its source pointer.
249 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment); 250 Type *LoadTy = Load->getType(); 251 unsigned AS = Load->getPointerAddressSpace(); 252 InstructionCost OldCost = 253 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind); 254 APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0); 255 OldCost += 256 TTI.getScalarizationOverhead(MinVecTy, DemandedElts, 257 /* Insert */ true, HasExtract, CostKind); 258 259 // New pattern: load VecPtr 260 InstructionCost NewCost = 261 TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind); 262 // Optionally, we are shuffling the loaded vector element(s) into place. 263 // For the mask set everything but element 0 to undef to prevent poison from 264 // propagating from the extra loaded memory. This will also optionally 265 // shrink/grow the vector from the loaded size to the output size. 266 // We assume this operation has no cost in codegen if there was no offset. 267 // Note that we could use freeze to avoid poison problems, but then we might 268 // still need a shuffle to change the vector size. 269 auto *Ty = cast<FixedVectorType>(I.getType()); 270 unsigned OutputNumElts = Ty->getNumElements(); 271 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem); 272 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big"); 273 Mask[0] = OffsetEltIndex; 274 if (OffsetEltIndex) 275 NewCost += 276 TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, MinVecTy, Mask, CostKind); 277 278 // We can aggressively convert to the vector form because the backend can 279 // invert this transform if it does not result in a performance win. 280 if (OldCost < NewCost || !NewCost.isValid()) 281 return false; 282 283 // It is safe and potentially profitable to load a vector directly: 284 // inselt undef, load Scalar, 0 --> load VecPtr 285 IRBuilder<> Builder(Load); 286 Value *CastedPtr = 287 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS)); 288 Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment); 289 VecLd = Builder.CreateShuffleVector(VecLd, Mask); 290 291 replaceValue(I, *VecLd); 292 ++NumVecLoad; 293 return true; 294 } 295 296 /// If we are loading a vector and then inserting it into a larger vector with 297 /// undefined elements, try to load the larger vector and eliminate the insert. 298 /// This removes a shuffle in IR and may allow combining of other loaded values. 299 bool VectorCombine::widenSubvectorLoad(Instruction &I) { 300 // Match subvector insert of fixed vector. 301 auto *Shuf = cast<ShuffleVectorInst>(&I); 302 if (!Shuf->isIdentityWithPadding()) 303 return false; 304 305 // Allow a non-canonical shuffle mask that is choosing elements from op1. 306 unsigned NumOpElts = 307 cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements(); 308 unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) { 309 return M >= (int)(NumOpElts); 310 }); 311 312 auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex)); 313 if (!canWidenLoad(Load, TTI)) 314 return false; 315 316 // We use minimal alignment (maximum flexibility) because we only care about 317 // the dereferenceable region. When calculating cost and creating a new op, 318 // we may use a larger value based on alignment attributes. 
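  // For reference, a sketch of the pattern handled here (types chosen
  // arbitrarily):
  //   %small = load <2 x float>, ptr %p
  //   %wide = shufflevector <2 x float> %small, <2 x float> poison,
  //                         <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
  // --> %wide = load <4 x float>, ptr %p   (when safe and not more costly)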
319 auto *Ty = cast<FixedVectorType>(I.getType()); 320 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); 321 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); 322 Align Alignment = Load->getAlign(); 323 if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT)) 324 return false; 325 326 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment); 327 Type *LoadTy = Load->getType(); 328 unsigned AS = Load->getPointerAddressSpace(); 329 330 // Original pattern: insert_subvector (load PtrOp) 331 // This conservatively assumes that the cost of a subvector insert into an 332 // undef value is 0. We could add that cost if the cost model accurately 333 // reflects the real cost of that operation. 334 InstructionCost OldCost = 335 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind); 336 337 // New pattern: load PtrOp 338 InstructionCost NewCost = 339 TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind); 340 341 // We can aggressively convert to the vector form because the backend can 342 // invert this transform if it does not result in a performance win. 343 if (OldCost < NewCost || !NewCost.isValid()) 344 return false; 345 346 IRBuilder<> Builder(Load); 347 Value *CastedPtr = 348 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS)); 349 Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment); 350 replaceValue(I, *VecLd); 351 ++NumVecLoad; 352 return true; 353 } 354 355 /// Determine which, if any, of the inputs should be replaced by a shuffle 356 /// followed by extract from a different index. 357 ExtractElementInst *VectorCombine::getShuffleExtract( 358 ExtractElementInst *Ext0, ExtractElementInst *Ext1, 359 unsigned PreferredExtractIndex = InvalidIndex) const { 360 auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand()); 361 auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand()); 362 assert(Index0C && Index1C && "Expected constant extract indexes"); 363 364 unsigned Index0 = Index0C->getZExtValue(); 365 unsigned Index1 = Index1C->getZExtValue(); 366 367 // If the extract indexes are identical, no shuffle is needed. 368 if (Index0 == Index1) 369 return nullptr; 370 371 Type *VecTy = Ext0->getVectorOperand()->getType(); 372 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types"); 373 InstructionCost Cost0 = 374 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0); 375 InstructionCost Cost1 = 376 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1); 377 378 // If both costs are invalid no shuffle is needed 379 if (!Cost0.isValid() && !Cost1.isValid()) 380 return nullptr; 381 382 // We are extracting from 2 different indexes, so one operand must be shuffled 383 // before performing a vector operation and/or extract. The more expensive 384 // extract will be replaced by a shuffle. 385 if (Cost0 > Cost1) 386 return Ext0; 387 if (Cost1 > Cost0) 388 return Ext1; 389 390 // If the costs are equal and there is a preferred extract index, shuffle the 391 // opposite operand. 392 if (PreferredExtractIndex == Index0) 393 return Ext1; 394 if (PreferredExtractIndex == Index1) 395 return Ext0; 396 397 // Otherwise, replace the extract with the higher index. 398 return Index0 > Index1 ? Ext0 : Ext1; 399 } 400 401 /// Compare the relative costs of 2 extracts followed by scalar operation vs. 402 /// vector operation(s) followed by extract. Return true if the existing 403 /// instructions are cheaper than a vector alternative. 
Otherwise, return false 404 /// and if one of the extracts should be transformed to a shufflevector, set 405 /// \p ConvertToShuffle to that extract instruction. 406 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0, 407 ExtractElementInst *Ext1, 408 const Instruction &I, 409 ExtractElementInst *&ConvertToShuffle, 410 unsigned PreferredExtractIndex) { 411 auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand()); 412 auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand()); 413 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes"); 414 415 unsigned Opcode = I.getOpcode(); 416 Value *Ext0Src = Ext0->getVectorOperand(); 417 Value *Ext1Src = Ext1->getVectorOperand(); 418 Type *ScalarTy = Ext0->getType(); 419 auto *VecTy = cast<VectorType>(Ext0Src->getType()); 420 InstructionCost ScalarOpCost, VectorOpCost; 421 422 // Get cost estimates for scalar and vector versions of the operation. 423 bool IsBinOp = Instruction::isBinaryOp(Opcode); 424 if (IsBinOp) { 425 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind); 426 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind); 427 } else { 428 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && 429 "Expected a compare"); 430 CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate(); 431 ScalarOpCost = TTI.getCmpSelInstrCost( 432 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind); 433 VectorOpCost = TTI.getCmpSelInstrCost( 434 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind); 435 } 436 437 // Get cost estimates for the extract elements. These costs will factor into 438 // both sequences. 439 unsigned Ext0Index = Ext0IndexC->getZExtValue(); 440 unsigned Ext1Index = Ext1IndexC->getZExtValue(); 441 442 InstructionCost Extract0Cost = 443 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index); 444 InstructionCost Extract1Cost = 445 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index); 446 447 // A more expensive extract will always be replaced by a splat shuffle. 448 // For example, if Ext0 is more expensive: 449 // opcode (extelt V0, Ext0), (ext V1, Ext1) --> 450 // extelt (opcode (splat V0, Ext0), V1), Ext1 451 // TODO: Evaluate whether that always results in lowest cost. Alternatively, 452 // check the cost of creating a broadcast shuffle and shuffling both 453 // operands to element 0. 454 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index; 455 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index; 456 InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost); 457 458 // Extra uses of the extracts mean that we include those costs in the 459 // vector total because those instructions will not be eliminated. 460 InstructionCost OldCost, NewCost; 461 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) { 462 // Handle a special case. If the 2 extracts are identical, adjust the 463 // formulas to account for that. The extra use charge allows for either the 464 // CSE'd pattern or an unoptimized form with identical values: 465 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C 466 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2) 467 : !Ext0->hasOneUse() || !Ext1->hasOneUse(); 468 OldCost = CheapExtractCost + ScalarOpCost; 469 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost; 470 } else { 471 // Handle the general case. 
Each extract is actually a different value: 472 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C 473 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost; 474 NewCost = VectorOpCost + CheapExtractCost + 475 !Ext0->hasOneUse() * Extract0Cost + 476 !Ext1->hasOneUse() * Extract1Cost; 477 } 478 479 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex); 480 if (ConvertToShuffle) { 481 if (IsBinOp && DisableBinopExtractShuffle) 482 return true; 483 484 // If we are extracting from 2 different indexes, then one operand must be 485 // shuffled before performing the vector operation. The shuffle mask is 486 // poison except for 1 lane that is being translated to the remaining 487 // extraction lane. Therefore, it is a splat shuffle. Ex: 488 // ShufMask = { poison, poison, 0, poison } 489 // TODO: The cost model has an option for a "broadcast" shuffle 490 // (splat-from-element-0), but no option for a more general splat. 491 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) { 492 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(), 493 PoisonMaskElem); 494 ShuffleMask[BestInsIndex] = BestExtIndex; 495 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, 496 VecTy, ShuffleMask, CostKind, 0, nullptr, 497 {ConvertToShuffle}); 498 } else { 499 NewCost += 500 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy, 501 {}, CostKind, 0, nullptr, {ConvertToShuffle}); 502 } 503 } 504 505 // Aggressively form a vector op if the cost is equal because the transform 506 // may enable further optimization. 507 // Codegen can reverse this transform (scalarize) if it was not profitable. 508 return OldCost < NewCost; 509 } 510 511 /// Create a shuffle that translates (shifts) 1 element from the input vector 512 /// to a new element location. 513 static Value *createShiftShuffle(Value *Vec, unsigned OldIndex, 514 unsigned NewIndex, IRBuilder<> &Builder) { 515 // The shuffle mask is poison except for 1 lane that is being translated 516 // to the new element index. Example for OldIndex == 2 and NewIndex == 0: 517 // ShufMask = { 2, poison, poison, poison } 518 auto *VecTy = cast<FixedVectorType>(Vec->getType()); 519 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem); 520 ShufMask[NewIndex] = OldIndex; 521 return Builder.CreateShuffleVector(Vec, ShufMask, "shift"); 522 } 523 524 /// Given an extract element instruction with constant index operand, shuffle 525 /// the source vector (shift the scalar element) to a NewIndex for extraction. 526 /// Return null if the input can be constant folded, so that we are not creating 527 /// unnecessary instructions. 528 static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt, 529 unsigned NewIndex, 530 IRBuilder<> &Builder) { 531 // Shufflevectors can only be created for fixed-width vectors. 532 Value *X = ExtElt->getVectorOperand(); 533 if (!isa<FixedVectorType>(X->getType())) 534 return nullptr; 535 536 // If the extract can be constant-folded, this code is unsimplified. Defer 537 // to other passes to handle that. 
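  // For illustration (indices chosen arbitrarily), translating
  //   %e = extractelement <4 x i32> %x, i32 2
  // to NewIndex 0 emits:
  //   %shift = shufflevector <4 x i32> %x, <4 x i32> poison,
  //            <4 x i32> <i32 2, i32 poison, i32 poison, i32 poison>
  //   %e.new = extractelement <4 x i32> %shift, i32 0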
538 Value *C = ExtElt->getIndexOperand(); 539 assert(isa<ConstantInt>(C) && "Expected a constant index operand"); 540 if (isa<Constant>(X)) 541 return nullptr; 542 543 Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(), 544 NewIndex, Builder); 545 return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex)); 546 } 547 548 /// Try to reduce extract element costs by converting scalar compares to vector 549 /// compares followed by extract. 550 /// cmp (ext0 V0, C), (ext1 V1, C) 551 void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0, 552 ExtractElementInst *Ext1, Instruction &I) { 553 assert(isa<CmpInst>(&I) && "Expected a compare"); 554 assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() == 555 cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() && 556 "Expected matching constant extract indexes"); 557 558 // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C 559 ++NumVecCmp; 560 CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate(); 561 Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand(); 562 Value *VecCmp = Builder.CreateCmp(Pred, V0, V1); 563 Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand()); 564 replaceValue(I, *NewExt); 565 } 566 567 /// Try to reduce extract element costs by converting scalar binops to vector 568 /// binops followed by extract. 569 /// bo (ext0 V0, C), (ext1 V1, C) 570 void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0, 571 ExtractElementInst *Ext1, Instruction &I) { 572 assert(isa<BinaryOperator>(&I) && "Expected a binary operator"); 573 assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() == 574 cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() && 575 "Expected matching constant extract indexes"); 576 577 // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C 578 ++NumVecBO; 579 Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand(); 580 Value *VecBO = 581 Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1); 582 583 // All IR flags are safe to back-propagate because any potential poison 584 // created in unused vector elements is discarded by the extract. 585 if (auto *VecBOInst = dyn_cast<Instruction>(VecBO)) 586 VecBOInst->copyIRFlags(&I); 587 588 Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand()); 589 replaceValue(I, *NewExt); 590 } 591 592 /// Match an instruction with extracted vector operands. 593 bool VectorCombine::foldExtractExtract(Instruction &I) { 594 // It is not safe to transform things like div, urem, etc. because we may 595 // create undefined behavior when executing those on unknown vector elements. 596 if (!isSafeToSpeculativelyExecute(&I)) 597 return false; 598 599 Instruction *I0, *I1; 600 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE; 601 if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) && 602 !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1)))) 603 return false; 604 605 Value *V0, *V1; 606 uint64_t C0, C1; 607 if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) || 608 !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) || 609 V0->getType() != V1->getType()) 610 return false; 611 612 // If the scalar value 'I' is going to be re-inserted into a vector, then try 613 // to create an extract to that same element. The extract/insert can be 614 // reduced to a "select shuffle". 
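  // For example (values chosen arbitrarily), if the only use of
  //   %s = fadd (extractelement <4 x float> %x, 3), (extractelement <4 x float> %y, 1)
  // is "insertelement %dst, %s, 3", then, when the extract costs are otherwise
  // equal, we prefer to keep extract index 3 and shuffle %y's element 1 up to
  // lane 3 rather than the other way around.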
615 // TODO: If we add a larger pattern match that starts from an insert, this 616 // probably becomes unnecessary. 617 auto *Ext0 = cast<ExtractElementInst>(I0); 618 auto *Ext1 = cast<ExtractElementInst>(I1); 619 uint64_t InsertIndex = InvalidIndex; 620 if (I.hasOneUse()) 621 match(I.user_back(), 622 m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex))); 623 624 ExtractElementInst *ExtractToChange; 625 if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex)) 626 return false; 627 628 if (ExtractToChange) { 629 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0; 630 ExtractElementInst *NewExtract = 631 translateExtract(ExtractToChange, CheapExtractIdx, Builder); 632 if (!NewExtract) 633 return false; 634 if (ExtractToChange == Ext0) 635 Ext0 = NewExtract; 636 else 637 Ext1 = NewExtract; 638 } 639 640 if (Pred != CmpInst::BAD_ICMP_PREDICATE) 641 foldExtExtCmp(Ext0, Ext1, I); 642 else 643 foldExtExtBinop(Ext0, Ext1, I); 644 645 Worklist.push(Ext0); 646 Worklist.push(Ext1); 647 return true; 648 } 649 650 /// Try to replace an extract + scalar fneg + insert with a vector fneg + 651 /// shuffle. 652 bool VectorCombine::foldInsExtFNeg(Instruction &I) { 653 // Match an insert (op (extract)) pattern. 654 Value *DestVec; 655 uint64_t Index; 656 Instruction *FNeg; 657 if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)), 658 m_ConstantInt(Index)))) 659 return false; 660 661 // Note: This handles the canonical fneg instruction and "fsub -0.0, X". 662 Value *SrcVec; 663 Instruction *Extract; 664 if (!match(FNeg, m_FNeg(m_CombineAnd( 665 m_Instruction(Extract), 666 m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index)))))) 667 return false; 668 669 auto *VecTy = cast<FixedVectorType>(I.getType()); 670 auto *ScalarTy = VecTy->getScalarType(); 671 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType()); 672 if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType()) 673 return false; 674 675 // Ignore bogus insert/extract index. 676 unsigned NumElts = VecTy->getNumElements(); 677 if (Index >= NumElts) 678 return false; 679 680 // We are inserting the negated element into the same lane that we extracted 681 // from. This is equivalent to a select-shuffle that chooses all but the 682 // negated element from the destination vector. 683 SmallVector<int> Mask(NumElts); 684 std::iota(Mask.begin(), Mask.end(), 0); 685 Mask[Index] = Index + NumElts; 686 InstructionCost OldCost = 687 TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) + 688 TTI.getVectorInstrCost(I, VecTy, CostKind, Index); 689 690 // If the extract has one use, it will be eliminated, so count it in the 691 // original cost. If it has more than one use, ignore the cost because it will 692 // be the same before/after. 693 if (Extract->hasOneUse()) 694 OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index); 695 696 InstructionCost NewCost = 697 TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) + 698 TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask, CostKind); 699 700 bool NeedLenChg = SrcVecTy->getNumElements() != NumElts; 701 // If the lengths of the two vectors are not equal, 702 // we need to add a length-change vector. Add this cost. 
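  // For example (sizes chosen arbitrarily), with DestVec <4 x float>,
  // SrcVec <8 x float> and Index 1: SrcMask = <poison, 1, poison, poison>
  // first reshapes the negated source to 4 elements, and the select mask
  // becomes <0, 5, 2, 3>.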
703 SmallVector<int> SrcMask; 704 if (NeedLenChg) { 705 SrcMask.assign(NumElts, PoisonMaskElem); 706 SrcMask[Index] = Index; 707 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, 708 SrcVecTy, SrcMask, CostKind); 709 } 710 711 if (NewCost > OldCost) 712 return false; 713 714 Value *NewShuf; 715 // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index 716 Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg); 717 if (NeedLenChg) { 718 // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask 719 Value *LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask); 720 NewShuf = Builder.CreateShuffleVector(DestVec, LenChgShuf, Mask); 721 } else { 722 // shuffle DestVec, (fneg SrcVec), Mask 723 NewShuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask); 724 } 725 726 replaceValue(I, *NewShuf); 727 return true; 728 } 729 730 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the 731 /// destination type followed by shuffle. This can enable further transforms by 732 /// moving bitcasts or shuffles together. 733 bool VectorCombine::foldBitcastShuffle(Instruction &I) { 734 Value *V0, *V1; 735 ArrayRef<int> Mask; 736 if (!match(&I, m_BitCast(m_OneUse( 737 m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask)))))) 738 return false; 739 740 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for 741 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle 742 // mask for scalable type is a splat or not. 743 // 2) Disallow non-vector casts. 744 // TODO: We could allow any shuffle. 745 auto *DestTy = dyn_cast<FixedVectorType>(I.getType()); 746 auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType()); 747 if (!DestTy || !SrcTy) 748 return false; 749 750 unsigned DestEltSize = DestTy->getScalarSizeInBits(); 751 unsigned SrcEltSize = SrcTy->getScalarSizeInBits(); 752 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0) 753 return false; 754 755 bool IsUnary = isa<UndefValue>(V1); 756 757 // For binary shuffles, only fold bitcast(shuffle(X,Y)) 758 // if it won't increase the number of bitcasts. 759 if (!IsUnary) { 760 auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType()); 761 auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType()); 762 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) && 763 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType())) 764 return false; 765 } 766 767 SmallVector<int, 16> NewMask; 768 if (DestEltSize <= SrcEltSize) { 769 // The bitcast is from wide to narrow/equal elements. The shuffle mask can 770 // always be expanded to the equivalent form choosing narrower elements. 771 assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask"); 772 unsigned ScaleFactor = SrcEltSize / DestEltSize; 773 narrowShuffleMaskElts(ScaleFactor, Mask, NewMask); 774 } else { 775 // The bitcast is from narrow elements to wide elements. The shuffle mask 776 // must choose consecutive elements to allow casting first. 777 assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask"); 778 unsigned ScaleFactor = DestEltSize / SrcEltSize; 779 if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask)) 780 return false; 781 } 782 783 // Bitcast the shuffle src - keep its original width but using the destination 784 // scalar type. 
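  // For example (a narrowing case, types chosen arbitrarily):
  //   bitcast (shufflevector <2 x i64> %v, <2 x i64> poison, <2 x i32> <i32 1, i32 0>) to <4 x i32>
  // --> shufflevector (bitcast <2 x i64> %v to <4 x i32>), poison,
  //                   <4 x i32> <i32 2, i32 3, i32 0, i32 1>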
785 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize; 786 auto *NewShuffleTy = 787 FixedVectorType::get(DestTy->getScalarType(), NumSrcElts); 788 auto *OldShuffleTy = 789 FixedVectorType::get(SrcTy->getScalarType(), Mask.size()); 790 unsigned NumOps = IsUnary ? 1 : 2; 791 792 // The new shuffle must not cost more than the old shuffle. 793 TargetTransformInfo::ShuffleKind SK = 794 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc 795 : TargetTransformInfo::SK_PermuteTwoSrc; 796 797 InstructionCost DestCost = 798 TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CostKind) + 799 (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy, 800 TargetTransformInfo::CastContextHint::None, 801 CostKind)); 802 InstructionCost SrcCost = 803 TTI.getShuffleCost(SK, SrcTy, Mask, CostKind) + 804 TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy, 805 TargetTransformInfo::CastContextHint::None, 806 CostKind); 807 if (DestCost > SrcCost || !DestCost.isValid()) 808 return false; 809 810 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC' 811 ++NumShufOfBitcast; 812 Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy); 813 Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy); 814 Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask); 815 replaceValue(I, *Shuf); 816 return true; 817 } 818 819 /// VP Intrinsics whose vector operands are both splat values may be simplified 820 /// into the scalar version of the operation and the result splatted. This 821 /// can lead to scalarization down the line. 822 bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) { 823 if (!isa<VPIntrinsic>(I)) 824 return false; 825 VPIntrinsic &VPI = cast<VPIntrinsic>(I); 826 Value *Op0 = VPI.getArgOperand(0); 827 Value *Op1 = VPI.getArgOperand(1); 828 829 if (!isSplatValue(Op0) || !isSplatValue(Op1)) 830 return false; 831 832 // Check getSplatValue early in this function, to avoid doing unnecessary 833 // work. 834 Value *ScalarOp0 = getSplatValue(Op0); 835 Value *ScalarOp1 = getSplatValue(Op1); 836 if (!ScalarOp0 || !ScalarOp1) 837 return false; 838 839 // For the binary VP intrinsics supported here, the result on disabled lanes 840 // is a poison value. For now, only do this simplification if all lanes 841 // are active. 842 // TODO: Relax the condition that all lanes are active by using insertelement 843 // on inactive lanes. 
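  // A sketch of the target pattern (both operands splats, mask all-true):
  //   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %xsplat, <4 x i32> %ysplat,
  //                                          <4 x i1> splat (i1 true), i32 %evl)
  // --> %s = add i32 %x, %y ; %r = splat of %s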
844 auto IsAllTrueMask = [](Value *MaskVal) { 845 if (Value *SplattedVal = getSplatValue(MaskVal)) 846 if (auto *ConstValue = dyn_cast<Constant>(SplattedVal)) 847 return ConstValue->isAllOnesValue(); 848 return false; 849 }; 850 if (!IsAllTrueMask(VPI.getArgOperand(2))) 851 return false; 852 853 // Check to make sure we support scalarization of the intrinsic 854 Intrinsic::ID IntrID = VPI.getIntrinsicID(); 855 if (!VPBinOpIntrinsic::isVPBinOp(IntrID)) 856 return false; 857 858 // Calculate cost of splatting both operands into vectors and the vector 859 // intrinsic 860 VectorType *VecTy = cast<VectorType>(VPI.getType()); 861 SmallVector<int> Mask; 862 if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy)) 863 Mask.resize(FVTy->getNumElements(), 0); 864 InstructionCost SplatCost = 865 TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) + 866 TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, Mask, 867 CostKind); 868 869 // Calculate the cost of the VP Intrinsic 870 SmallVector<Type *, 4> Args; 871 for (Value *V : VPI.args()) 872 Args.push_back(V->getType()); 873 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args); 874 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind); 875 InstructionCost OldCost = 2 * SplatCost + VectorOpCost; 876 877 // Determine scalar opcode 878 std::optional<unsigned> FunctionalOpcode = 879 VPI.getFunctionalOpcode(); 880 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt; 881 if (!FunctionalOpcode) { 882 ScalarIntrID = VPI.getFunctionalIntrinsicID(); 883 if (!ScalarIntrID) 884 return false; 885 } 886 887 // Calculate cost of scalarizing 888 InstructionCost ScalarOpCost = 0; 889 if (ScalarIntrID) { 890 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args); 891 ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind); 892 } else { 893 ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode, 894 VecTy->getScalarType(), CostKind); 895 } 896 897 // The existing splats may be kept around if other instructions use them. 898 InstructionCost CostToKeepSplats = 899 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse()); 900 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats; 901 902 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI 903 << "\n"); 904 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost 905 << ", Cost of scalarizing:" << NewCost << "\n"); 906 907 // We want to scalarize unless the vector variant actually has lower cost. 908 if (OldCost < NewCost || !NewCost.isValid()) 909 return false; 910 911 // Scalarize the intrinsic 912 ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount(); 913 Value *EVL = VPI.getArgOperand(3); 914 915 // If the VP op might introduce UB or poison, we can scalarize it provided 916 // that we know the EVL > 0: If the EVL is zero, then the original VP op 917 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by 918 // scalarizing it. 919 bool SafeToSpeculate; 920 if (ScalarIntrID) 921 SafeToSpeculate = Intrinsic::getAttributes(I.getContext(), *ScalarIntrID) 922 .hasFnAttr(Attribute::AttrKind::Speculatable); 923 else 924 SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode( 925 *FunctionalOpcode, &VPI, nullptr, &AC, &DT); 926 if (!SafeToSpeculate && 927 !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI))) 928 return false; 929 930 Value *ScalarVal = 931 ScalarIntrID 932 ? 
Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID, 933 {ScalarOp0, ScalarOp1}) 934 : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode), 935 ScalarOp0, ScalarOp1); 936 937 replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal)); 938 return true; 939 } 940 941 /// Match a vector binop or compare instruction with at least one inserted 942 /// scalar operand and convert to scalar binop/cmp followed by insertelement. 943 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { 944 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE; 945 Value *Ins0, *Ins1; 946 if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) && 947 !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) 948 return false; 949 950 // Do not convert the vector condition of a vector select into a scalar 951 // condition. That may cause problems for codegen because of differences in 952 // boolean formats and register-file transfers. 953 // TODO: Can we account for that in the cost model? 954 bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE; 955 if (IsCmp) 956 for (User *U : I.users()) 957 if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value()))) 958 return false; 959 960 // Match against one or both scalar values being inserted into constant 961 // vectors: 962 // vec_op VecC0, (inselt VecC1, V1, Index) 963 // vec_op (inselt VecC0, V0, Index), VecC1 964 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) 965 // TODO: Deal with mismatched index constants and variable indexes? 966 Constant *VecC0 = nullptr, *VecC1 = nullptr; 967 Value *V0 = nullptr, *V1 = nullptr; 968 uint64_t Index0 = 0, Index1 = 0; 969 if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0), 970 m_ConstantInt(Index0))) && 971 !match(Ins0, m_Constant(VecC0))) 972 return false; 973 if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1), 974 m_ConstantInt(Index1))) && 975 !match(Ins1, m_Constant(VecC1))) 976 return false; 977 978 bool IsConst0 = !V0; 979 bool IsConst1 = !V1; 980 if (IsConst0 && IsConst1) 981 return false; 982 if (!IsConst0 && !IsConst1 && Index0 != Index1) 983 return false; 984 985 auto *VecTy0 = cast<VectorType>(Ins0->getType()); 986 auto *VecTy1 = cast<VectorType>(Ins1->getType()); 987 if (VecTy0->getElementCount().getKnownMinValue() <= Index0 || 988 VecTy1->getElementCount().getKnownMinValue() <= Index1) 989 return false; 990 991 // Bail for single insertion if it is a load. 992 // TODO: Handle this once getVectorInstrCost can cost for load/stores. 993 auto *I0 = dyn_cast_or_null<Instruction>(V0); 994 auto *I1 = dyn_cast_or_null<Instruction>(V1); 995 if ((IsConst0 && I1 && I1->mayReadFromMemory()) || 996 (IsConst1 && I0 && I0->mayReadFromMemory())) 997 return false; 998 999 uint64_t Index = IsConst0 ? Index1 : Index0; 1000 Type *ScalarTy = IsConst0 ? 
V1->getType() : V0->getType(); 1001 Type *VecTy = I.getType(); 1002 assert(VecTy->isVectorTy() && 1003 (IsConst0 || IsConst1 || V0->getType() == V1->getType()) && 1004 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() || 1005 ScalarTy->isPointerTy()) && 1006 "Unexpected types for insert element into binop or cmp"); 1007 1008 unsigned Opcode = I.getOpcode(); 1009 InstructionCost ScalarOpCost, VectorOpCost; 1010 if (IsCmp) { 1011 CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate(); 1012 ScalarOpCost = TTI.getCmpSelInstrCost( 1013 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind); 1014 VectorOpCost = TTI.getCmpSelInstrCost( 1015 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind); 1016 } else { 1017 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind); 1018 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind); 1019 } 1020 1021 // Get cost estimate for the insert element. This cost will factor into 1022 // both sequences. 1023 InstructionCost InsertCost = TTI.getVectorInstrCost( 1024 Instruction::InsertElement, VecTy, CostKind, Index); 1025 InstructionCost OldCost = 1026 (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) + VectorOpCost; 1027 InstructionCost NewCost = ScalarOpCost + InsertCost + 1028 (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) + 1029 (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost); 1030 1031 // We want to scalarize unless the vector variant actually has lower cost. 1032 if (OldCost < NewCost || !NewCost.isValid()) 1033 return false; 1034 1035 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) --> 1036 // inselt NewVecC, (scalar_op V0, V1), Index 1037 if (IsCmp) 1038 ++NumScalarCmp; 1039 else 1040 ++NumScalarBO; 1041 1042 // For constant cases, extract the scalar element, this should constant fold. 1043 if (IsConst0) 1044 V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index)); 1045 if (IsConst1) 1046 V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index)); 1047 1048 Value *Scalar = 1049 IsCmp ? Builder.CreateCmp(Pred, V0, V1) 1050 : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1); 1051 1052 Scalar->setName(I.getName() + ".scalar"); 1053 1054 // All IR flags are safe to back-propagate. There is no potential for extra 1055 // poison to be created by the scalar instruction. 1056 if (auto *ScalarInst = dyn_cast<Instruction>(Scalar)) 1057 ScalarInst->copyIRFlags(&I); 1058 1059 // Fold the vector constants in the original vectors into a new base vector. 1060 Value *NewVecC = 1061 IsCmp ? Builder.CreateCmp(Pred, VecC0, VecC1) 1062 : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1); 1063 Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index); 1064 replaceValue(I, *Insert); 1065 return true; 1066 } 1067 1068 /// Try to combine a scalar binop + 2 scalar compares of extracted elements of 1069 /// a vector into vector operations followed by extract. Note: The SLP pass 1070 /// may miss this pattern because of implementation problems. 1071 bool VectorCombine::foldExtractedCmps(Instruction &I) { 1072 auto *BI = dyn_cast<BinaryOperator>(&I); 1073 1074 // We are looking for a scalar binop of booleans. 1075 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1) 1076 if (!BI || !I.getType()->isIntegerTy(1)) 1077 return false; 1078 1079 // The compare predicates should match, and each compare should have a 1080 // constant operand. 
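  // A concrete sketch (constants and indexes chosen arbitrarily):
  //   %e0 = extractelement <4 x i32> %x, i32 0
  //   %e1 = extractelement <4 x i32> %x, i32 2
  //   %r = and i1 (icmp sgt i32 %e0, 42), (icmp sgt i32 %e1, 99)
  // becomes a vector icmp of %x against <42, poison, 99, poison>, an 'and' of
  // that mask with a shifted copy of itself, and a single extract.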
1081 Value *B0 = I.getOperand(0), *B1 = I.getOperand(1); 1082 Instruction *I0, *I1; 1083 Constant *C0, *C1; 1084 CmpPredicate P0, P1; 1085 // FIXME: Use CmpPredicate::getMatching here. 1086 if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) || 1087 !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) || 1088 P0 != static_cast<CmpInst::Predicate>(P1)) 1089 return false; 1090 1091 // The compare operands must be extracts of the same vector with constant 1092 // extract indexes. 1093 Value *X; 1094 uint64_t Index0, Index1; 1095 if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) || 1096 !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))) 1097 return false; 1098 1099 auto *Ext0 = cast<ExtractElementInst>(I0); 1100 auto *Ext1 = cast<ExtractElementInst>(I1); 1101 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, CostKind); 1102 if (!ConvertToShuf) 1103 return false; 1104 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) && 1105 "Unknown ExtractElementInst"); 1106 1107 // The original scalar pattern is: 1108 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1) 1109 CmpInst::Predicate Pred = P0; 1110 unsigned CmpOpcode = 1111 CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp; 1112 auto *VecTy = dyn_cast<FixedVectorType>(X->getType()); 1113 if (!VecTy) 1114 return false; 1115 1116 InstructionCost Ext0Cost = 1117 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0); 1118 InstructionCost Ext1Cost = 1119 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1); 1120 InstructionCost CmpCost = TTI.getCmpSelInstrCost( 1121 CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred, 1122 CostKind); 1123 1124 InstructionCost OldCost = 1125 Ext0Cost + Ext1Cost + CmpCost * 2 + 1126 TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind); 1127 1128 // The proposed vector pattern is: 1129 // vcmp = cmp Pred X, VecC 1130 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0 1131 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0; 1132 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1; 1133 auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType())); 1134 InstructionCost NewCost = TTI.getCmpSelInstrCost( 1135 CmpOpcode, X->getType(), CmpInst::makeCmpResultType(X->getType()), Pred, 1136 CostKind); 1137 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem); 1138 ShufMask[CheapIndex] = ExpensiveIndex; 1139 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy, 1140 ShufMask, CostKind); 1141 NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind); 1142 NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex); 1143 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost; 1144 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost; 1145 1146 // Aggressively form vector ops if the cost is equal because the transform 1147 // may enable further optimization. 1148 // Codegen can reverse this transform (scalarize) if it was not profitable. 1149 if (OldCost < NewCost || !NewCost.isValid()) 1150 return false; 1151 1152 // Create a vector constant from the 2 scalar constants. 
1153 SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(), 1154 PoisonValue::get(VecTy->getElementType())); 1155 CmpC[Index0] = C0; 1156 CmpC[Index1] = C1; 1157 Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC)); 1158 Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder); 1159 Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp; 1160 Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf; 1161 Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS); 1162 Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex); 1163 replaceValue(I, *NewExt); 1164 ++NumVecCmpBO; 1165 return true; 1166 } 1167 1168 // Check if a memory location is modified between two instructions in the same BB. 1169 static bool isMemModifiedBetween(BasicBlock::iterator Begin, 1170 BasicBlock::iterator End, 1171 const MemoryLocation &Loc, AAResults &AA) { 1172 unsigned NumScanned = 0; 1173 return std::any_of(Begin, End, [&](const Instruction &Instr) { 1174 return isModSet(AA.getModRefInfo(&Instr, Loc)) || 1175 ++NumScanned > MaxInstrsToScan; 1176 }); 1177 } 1178 1179 namespace { 1180 /// Helper class to indicate whether a vector index can be safely scalarized and 1181 /// if a freeze needs to be inserted. 1182 class ScalarizationResult { 1183 enum class StatusTy { Unsafe, Safe, SafeWithFreeze }; 1184 1185 StatusTy Status; 1186 Value *ToFreeze; 1187 1188 ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr) 1189 : Status(Status), ToFreeze(ToFreeze) {} 1190 1191 public: 1192 ScalarizationResult(const ScalarizationResult &Other) = default; 1193 ~ScalarizationResult() { 1194 assert(!ToFreeze && "freeze() not called with ToFreeze being set"); 1195 } 1196 1197 static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; } 1198 static ScalarizationResult safe() { return {StatusTy::Safe}; } 1199 static ScalarizationResult safeWithFreeze(Value *ToFreeze) { 1200 return {StatusTy::SafeWithFreeze, ToFreeze}; 1201 } 1202 1203 /// Returns true if the index can be scalarized without requiring a freeze. 1204 bool isSafe() const { return Status == StatusTy::Safe; } 1205 /// Returns true if the index cannot be scalarized. 1206 bool isUnsafe() const { return Status == StatusTy::Unsafe; } 1207 /// Returns true if the index can be scalarized, but requires inserting a 1208 /// freeze. 1209 bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; } 1210 1211 /// Reset the state to Unsafe and clear ToFreeze if set. 1212 void discard() { 1213 ToFreeze = nullptr; 1214 Status = StatusTy::Unsafe; 1215 } 1216 1217 /// Freeze ToFreeze and update the use in \p UserI to use it. 1218 void freeze(IRBuilder<> &Builder, Instruction &UserI) { 1219 assert(isSafeWithFreeze() && 1220 "should only be used when freezing is required"); 1221 assert(is_contained(ToFreeze->users(), &UserI) && 1222 "UserI must be a user of ToFreeze"); 1223 IRBuilder<>::InsertPointGuard Guard(Builder); 1224 Builder.SetInsertPoint(cast<Instruction>(&UserI)); 1225 Value *Frozen = 1226 Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen"); 1227 for (Use &U : make_early_inc_range((UserI.operands()))) 1228 if (U.get() == ToFreeze) 1229 U.set(Frozen); 1230 1231 ToFreeze = nullptr; 1232 } 1233 }; 1234 } // namespace 1235 1236 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p 1237 /// Idx. \p Idx must access a valid vector element.
1238 static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx, 1239 Instruction *CtxI, 1240 AssumptionCache &AC, 1241 const DominatorTree &DT) { 1242 // We do checks for both fixed vector types and scalable vector types. 1243 // This is the number of elements of fixed vector types, 1244 // or the minimum number of elements of scalable vector types. 1245 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue(); 1246 1247 if (auto *C = dyn_cast<ConstantInt>(Idx)) { 1248 if (C->getValue().ult(NumElements)) 1249 return ScalarizationResult::safe(); 1250 return ScalarizationResult::unsafe(); 1251 } 1252 1253 unsigned IntWidth = Idx->getType()->getScalarSizeInBits(); 1254 APInt Zero(IntWidth, 0); 1255 APInt MaxElts(IntWidth, NumElements); 1256 ConstantRange ValidIndices(Zero, MaxElts); 1257 ConstantRange IdxRange(IntWidth, true); 1258 1259 if (isGuaranteedNotToBePoison(Idx, &AC)) { 1260 if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false, 1261 true, &AC, CtxI, &DT))) 1262 return ScalarizationResult::safe(); 1263 return ScalarizationResult::unsafe(); 1264 } 1265 1266 // If the index may be poison, check if we can insert a freeze before the 1267 // range of the index is restricted. 1268 Value *IdxBase; 1269 ConstantInt *CI; 1270 if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) { 1271 IdxRange = IdxRange.binaryAnd(CI->getValue()); 1272 } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) { 1273 IdxRange = IdxRange.urem(CI->getValue()); 1274 } 1275 1276 if (ValidIndices.contains(IdxRange)) 1277 return ScalarizationResult::safeWithFreeze(IdxBase); 1278 return ScalarizationResult::unsafe(); 1279 } 1280 1281 /// The memory operation on a vector of \p ScalarType had alignment of 1282 /// \p VectorAlignment. Compute the maximal, but conservatively correct, 1283 /// alignment that will be valid for the memory operation on a single scalar 1284 /// element of the same type with index \p Idx. 1285 static Align computeAlignmentAfterScalarization(Align VectorAlignment, 1286 Type *ScalarType, Value *Idx, 1287 const DataLayout &DL) { 1288 if (auto *C = dyn_cast<ConstantInt>(Idx)) 1289 return commonAlignment(VectorAlignment, 1290 C->getZExtValue() * DL.getTypeStoreSize(ScalarType)); 1291 return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType)); 1292 } 1293 1294 // Combine patterns like: 1295 // %0 = load <4 x i32>, <4 x i32>* %a 1296 // %1 = insertelement <4 x i32> %0, i32 %b, i32 1 1297 // store <4 x i32> %1, <4 x i32>* %a 1298 // to: 1299 // %0 = bitcast <4 x i32>* %a to i32* 1300 // %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1 1301 // store i32 %b, i32* %1 1302 bool VectorCombine::foldSingleElementStore(Instruction &I) { 1303 auto *SI = cast<StoreInst>(&I); 1304 if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType())) 1305 return false; 1306 1307 // TODO: Combine more complicated patterns (multiple insert) by referencing 1308 // TargetTransformInfo. 1309 Instruction *Source; 1310 Value *NewElement; 1311 Value *Idx; 1312 if (!match(SI->getValueOperand(), 1313 m_InsertElt(m_Instruction(Source), m_Value(NewElement), 1314 m_Value(Idx)))) 1315 return false; 1316 1317 if (auto *Load = dyn_cast<LoadInst>(Source)) { 1318 auto VecTy = cast<VectorType>(SI->getValueOperand()->getType()); 1319 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts(); 1320 // Don't optimize for atomic/volatile load or store. 
Ensure memory is not 1321 // modified between, vector type matches store size, and index is inbounds. 1322 if (!Load->isSimple() || Load->getParent() != SI->getParent() || 1323 !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) || 1324 SrcAddr != SI->getPointerOperand()->stripPointerCasts()) 1325 return false; 1326 1327 auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT); 1328 if (ScalarizableIdx.isUnsafe() || 1329 isMemModifiedBetween(Load->getIterator(), SI->getIterator(), 1330 MemoryLocation::get(SI), AA)) 1331 return false; 1332 1333 if (ScalarizableIdx.isSafeWithFreeze()) 1334 ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx)); 1335 Value *GEP = Builder.CreateInBoundsGEP( 1336 SI->getValueOperand()->getType(), SI->getPointerOperand(), 1337 {ConstantInt::get(Idx->getType(), 0), Idx}); 1338 StoreInst *NSI = Builder.CreateStore(NewElement, GEP); 1339 NSI->copyMetadata(*SI); 1340 Align ScalarOpAlignment = computeAlignmentAfterScalarization( 1341 std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx, 1342 *DL); 1343 NSI->setAlignment(ScalarOpAlignment); 1344 replaceValue(I, *NSI); 1345 eraseInstruction(I); 1346 return true; 1347 } 1348 1349 return false; 1350 } 1351 1352 /// Try to scalarize vector loads feeding extractelement instructions. 1353 bool VectorCombine::scalarizeLoadExtract(Instruction &I) { 1354 Value *Ptr; 1355 if (!match(&I, m_Load(m_Value(Ptr)))) 1356 return false; 1357 1358 auto *VecTy = cast<VectorType>(I.getType()); 1359 auto *LI = cast<LoadInst>(&I); 1360 if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType())) 1361 return false; 1362 1363 InstructionCost OriginalCost = 1364 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(), 1365 LI->getPointerAddressSpace(), CostKind); 1366 InstructionCost ScalarizedCost = 0; 1367 1368 Instruction *LastCheckedInst = LI; 1369 unsigned NumInstChecked = 0; 1370 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze; 1371 auto FailureGuard = make_scope_exit([&]() { 1372 // If the transform is aborted, discard the ScalarizationResults. 1373 for (auto &Pair : NeedFreeze) 1374 Pair.second.discard(); 1375 }); 1376 1377 // Check if all users of the load are extracts with no memory modifications 1378 // between the load and the extract. Compute the cost of both the original 1379 // code and the scalarized version. 1380 for (User *U : LI->users()) { 1381 auto *UI = dyn_cast<ExtractElementInst>(U); 1382 if (!UI || UI->getParent() != LI->getParent()) 1383 return false; 1384 1385 // Check if any instruction between the load and the extract may modify 1386 // memory. 1387 if (LastCheckedInst->comesBefore(UI)) { 1388 for (Instruction &I : 1389 make_range(std::next(LI->getIterator()), UI->getIterator())) { 1390 // Bail out if we reached the check limit or the instruction may write 1391 // to memory. 1392 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory()) 1393 return false; 1394 NumInstChecked++; 1395 } 1396 LastCheckedInst = UI; 1397 } 1398 1399 auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT); 1400 if (ScalarIdx.isUnsafe()) 1401 return false; 1402 if (ScalarIdx.isSafeWithFreeze()) { 1403 NeedFreeze.try_emplace(UI, ScalarIdx); 1404 ScalarIdx.discard(); 1405 } 1406 1407 auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1)); 1408 OriginalCost += 1409 TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, 1410 Index ? 
Index->getZExtValue() : -1); 1411 ScalarizedCost += 1412 TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(), 1413 Align(1), LI->getPointerAddressSpace(), CostKind); 1414 ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType()); 1415 } 1416 1417 if (ScalarizedCost >= OriginalCost) 1418 return false; 1419 1420 // Replace extracts with narrow scalar loads. 1421 for (User *U : LI->users()) { 1422 auto *EI = cast<ExtractElementInst>(U); 1423 Value *Idx = EI->getOperand(1); 1424 1425 // Insert 'freeze' for poison indexes. 1426 auto It = NeedFreeze.find(EI); 1427 if (It != NeedFreeze.end()) 1428 It->second.freeze(Builder, *cast<Instruction>(Idx)); 1429 1430 Builder.SetInsertPoint(EI); 1431 Value *GEP = 1432 Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx}); 1433 auto *NewLoad = cast<LoadInst>(Builder.CreateLoad( 1434 VecTy->getElementType(), GEP, EI->getName() + ".scalar")); 1435 1436 Align ScalarOpAlignment = computeAlignmentAfterScalarization( 1437 LI->getAlign(), VecTy->getElementType(), Idx, *DL); 1438 NewLoad->setAlignment(ScalarOpAlignment); 1439 1440 replaceValue(*EI, *NewLoad); 1441 } 1442 1443 FailureGuard.release(); 1444 return true; 1445 } 1446 1447 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" 1448 /// to "(bitcast (concat X, Y))" 1449 /// where X/Y are bitcasted from i1 mask vectors. 1450 bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) { 1451 Type *Ty = I.getType(); 1452 if (!Ty->isIntegerTy()) 1453 return false; 1454 1455 // TODO: Add big endian test coverage 1456 if (DL->isBigEndian()) 1457 return false; 1458 1459 // Restrict to disjoint cases so the mask vectors aren't overlapping. 1460 Instruction *X, *Y; 1461 if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y)))) 1462 return false; 1463 1464 // Allow both sources to contain shl, to handle more generic pattern: 1465 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))" 1466 Value *SrcX; 1467 uint64_t ShAmtX = 0; 1468 if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) && 1469 !match(X, m_OneUse( 1470 m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))), 1471 m_ConstantInt(ShAmtX))))) 1472 return false; 1473 1474 Value *SrcY; 1475 uint64_t ShAmtY = 0; 1476 if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) && 1477 !match(Y, m_OneUse( 1478 m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))), 1479 m_ConstantInt(ShAmtY))))) 1480 return false; 1481 1482 // Canonicalize larger shift to the RHS. 1483 if (ShAmtX > ShAmtY) { 1484 std::swap(X, Y); 1485 std::swap(SrcX, SrcY); 1486 std::swap(ShAmtX, ShAmtY); 1487 } 1488 1489 // Ensure both sources are matching vXi1 bool mask types, and that the shift 1490 // difference is the mask width so they can be easily concatenated together. 
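  // As an illustrative sketch of the shape being checked here (X and Y are
  // placeholder <8 x i1> masks feeding an i32 result):
  //   or (zext (bitcast X to i8) to i32),
  //      (shl (zext (bitcast Y to i8) to i32), 8)
  // --> zext (bitcast (concat X, Y) to i16) to i32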
1491 uint64_t ShAmtDiff = ShAmtY - ShAmtX; 1492 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0); 1493 unsigned BitWidth = Ty->getPrimitiveSizeInBits(); 1494 auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType()); 1495 if (!MaskTy || SrcX->getType() != SrcY->getType() || 1496 !MaskTy->getElementType()->isIntegerTy(1) || 1497 MaskTy->getNumElements() != ShAmtDiff || 1498 MaskTy->getNumElements() > (BitWidth / 2)) 1499 return false; 1500 1501 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy); 1502 auto *ConcatIntTy = 1503 Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements()); 1504 auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff); 1505 1506 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements()); 1507 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 1508 1509 // TODO: Is it worth supporting multi use cases? 1510 InstructionCost OldCost = 0; 1511 OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind); 1512 OldCost += 1513 NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind); 1514 OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy, 1515 TTI::CastContextHint::None, CostKind); 1516 OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy, 1517 TTI::CastContextHint::None, CostKind); 1518 1519 InstructionCost NewCost = 0; 1520 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, MaskTy, 1521 ConcatMask, CostKind); 1522 NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy, 1523 TTI::CastContextHint::None, CostKind); 1524 if (Ty != ConcatIntTy) 1525 NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy, 1526 TTI::CastContextHint::None, CostKind); 1527 if (ShAmtX > 0) 1528 NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind); 1529 1530 if (NewCost > OldCost) 1531 return false; 1532 1533 // Build bool mask concatenation, bitcast back to scalar integer, and perform 1534 // any residual zero-extension or shifting. 1535 Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask); 1536 Worklist.pushValue(Concat); 1537 1538 Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy); 1539 1540 if (Ty != ConcatIntTy) { 1541 Worklist.pushValue(Result); 1542 Result = Builder.CreateZExt(Result, Ty); 1543 } 1544 1545 if (ShAmtX > 0) { 1546 Worklist.pushValue(Result); 1547 Result = Builder.CreateShl(Result, ShAmtX); 1548 } 1549 1550 replaceValue(I, *Result); 1551 return true; 1552 } 1553 1554 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef" 1555 /// --> "binop (shuffle), (shuffle)". 1556 bool VectorCombine::foldPermuteOfBinops(Instruction &I) { 1557 BinaryOperator *BinOp; 1558 ArrayRef<int> OuterMask; 1559 if (!match(&I, 1560 m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask)))) 1561 return false; 1562 1563 // Don't introduce poison into div/rem. 
1564 if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem)) 1565 return false; 1566 1567 Value *Op00, *Op01; 1568 ArrayRef<int> Mask0; 1569 if (!match(BinOp->getOperand(0), 1570 m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0))))) 1571 return false; 1572 1573 Value *Op10, *Op11; 1574 ArrayRef<int> Mask1; 1575 if (!match(BinOp->getOperand(1), 1576 m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1))))) 1577 return false; 1578 1579 Instruction::BinaryOps Opcode = BinOp->getOpcode(); 1580 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1581 auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType()); 1582 auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType()); 1583 auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType()); 1584 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty) 1585 return false; 1586 1587 unsigned NumSrcElts = BinOpTy->getNumElements(); 1588 1589 // Don't accept shuffles that reference the second operand in 1590 // div/rem or if its an undef arg. 1591 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) && 1592 any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; })) 1593 return false; 1594 1595 // Merge outer / inner shuffles. 1596 SmallVector<int> NewMask0, NewMask1; 1597 for (int M : OuterMask) { 1598 if (M < 0 || M >= (int)NumSrcElts) { 1599 NewMask0.push_back(PoisonMaskElem); 1600 NewMask1.push_back(PoisonMaskElem); 1601 } else { 1602 NewMask0.push_back(Mask0[M]); 1603 NewMask1.push_back(Mask1[M]); 1604 } 1605 } 1606 1607 // Try to merge shuffles across the binop if the new shuffles are not costly. 1608 InstructionCost OldCost = 1609 TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) + 1610 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, BinOpTy, 1611 OuterMask, CostKind, 0, nullptr, {BinOp}, &I) + 1612 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, Mask0, 1613 CostKind, 0, nullptr, {Op00, Op01}, 1614 cast<Instruction>(BinOp->getOperand(0))) + 1615 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, Mask1, 1616 CostKind, 0, nullptr, {Op10, Op11}, 1617 cast<Instruction>(BinOp->getOperand(1))); 1618 1619 InstructionCost NewCost = 1620 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op0Ty, NewMask0, 1621 CostKind, 0, nullptr, {Op00, Op01}) + 1622 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, Op1Ty, NewMask1, 1623 CostKind, 0, nullptr, {Op10, Op11}) + 1624 TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind); 1625 1626 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I 1627 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1628 << "\n"); 1629 1630 // If costs are equal, still fold as we reduce instruction count. 1631 if (NewCost > OldCost) 1632 return false; 1633 1634 Value *Shuf0 = Builder.CreateShuffleVector(Op00, Op01, NewMask0); 1635 Value *Shuf1 = Builder.CreateShuffleVector(Op10, Op11, NewMask1); 1636 Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1); 1637 1638 // Intersect flags from the old binops. 1639 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) 1640 NewInst->copyIRFlags(BinOp); 1641 1642 Worklist.pushValue(Shuf0); 1643 Worklist.pushValue(Shuf1); 1644 replaceValue(I, *NewBO); 1645 return true; 1646 } 1647 1648 /// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)". 1649 /// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)". 
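/// As an illustrative sketch (%x/%y/%z/%w are placeholder operands):
///   %a = add <4 x i32> %x, %y
///   %b = add <4 x i32> %z, %w
///   %r = shufflevector <4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
/// -->
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> %z, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> %w, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
///   %r  = add <4 x i32> %s0, %s1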
1650 bool VectorCombine::foldShuffleOfBinops(Instruction &I) { 1651 ArrayRef<int> OldMask; 1652 Instruction *LHS, *RHS; 1653 if (!match(&I, m_Shuffle(m_OneUse(m_Instruction(LHS)), 1654 m_OneUse(m_Instruction(RHS)), m_Mask(OldMask)))) 1655 return false; 1656 1657 // TODO: Add support for addlike etc. 1658 if (LHS->getOpcode() != RHS->getOpcode()) 1659 return false; 1660 1661 Value *X, *Y, *Z, *W; 1662 bool IsCommutative = false; 1663 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE; 1664 if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) && 1665 match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) { 1666 auto *BO = cast<BinaryOperator>(LHS); 1667 // Don't introduce poison into div/rem. 1668 if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem()) 1669 return false; 1670 IsCommutative = BinaryOperator::isCommutative(BO->getOpcode()); 1671 } else if (match(LHS, m_Cmp(Pred, m_Value(X), m_Value(Y))) && 1672 match(RHS, m_SpecificCmp(Pred, m_Value(Z), m_Value(W)))) { 1673 IsCommutative = cast<CmpInst>(LHS)->isCommutative(); 1674 } else 1675 return false; 1676 1677 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1678 auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType()); 1679 auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType()); 1680 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType()) 1681 return false; 1682 1683 unsigned NumSrcElts = BinOpTy->getNumElements(); 1684 1685 // If we have something like "add X, Y" and "add Z, X", swap ops to match. 1686 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z)) 1687 std::swap(X, Y); 1688 1689 auto ConvertToUnary = [NumSrcElts](int &M) { 1690 if (M >= (int)NumSrcElts) 1691 M -= NumSrcElts; 1692 }; 1693 1694 SmallVector<int> NewMask0(OldMask); 1695 TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc; 1696 if (X == Z) { 1697 llvm::for_each(NewMask0, ConvertToUnary); 1698 SK0 = TargetTransformInfo::SK_PermuteSingleSrc; 1699 Z = PoisonValue::get(BinOpTy); 1700 } 1701 1702 SmallVector<int> NewMask1(OldMask); 1703 TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc; 1704 if (Y == W) { 1705 llvm::for_each(NewMask1, ConvertToUnary); 1706 SK1 = TargetTransformInfo::SK_PermuteSingleSrc; 1707 W = PoisonValue::get(BinOpTy); 1708 } 1709 1710 // Try to replace a binop with a shuffle if the shuffle is not costly. 1711 InstructionCost OldCost = 1712 TTI.getInstructionCost(LHS, CostKind) + 1713 TTI.getInstructionCost(RHS, CostKind) + 1714 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinResTy, 1715 OldMask, CostKind, 0, nullptr, {LHS, RHS}, &I); 1716 1717 InstructionCost NewCost = 1718 TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) + 1719 TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W}); 1720 1721 if (Pred == CmpInst::BAD_ICMP_PREDICATE) { 1722 NewCost += 1723 TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind); 1724 } else { 1725 auto *ShuffleCmpTy = 1726 FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy); 1727 NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, 1728 ShuffleDstTy, Pred, CostKind); 1729 } 1730 1731 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I 1732 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1733 << "\n"); 1734 1735 // If either shuffle will constant fold away, then fold for the same cost as 1736 // we will reduce the instruction count. 
1737 bool ReducedInstCount = (isa<Constant>(X) && isa<Constant>(Z)) || 1738 (isa<Constant>(Y) && isa<Constant>(W)); 1739 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost)) 1740 return false; 1741 1742 Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0); 1743 Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1); 1744 Value *NewBO = Pred == CmpInst::BAD_ICMP_PREDICATE 1745 ? Builder.CreateBinOp( 1746 cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1) 1747 : Builder.CreateCmp(Pred, Shuf0, Shuf1); 1748 1749 // Intersect flags from the old binops. 1750 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) { 1751 NewInst->copyIRFlags(LHS); 1752 NewInst->andIRFlags(RHS); 1753 } 1754 1755 Worklist.pushValue(Shuf0); 1756 Worklist.pushValue(Shuf1); 1757 replaceValue(I, *NewBO); 1758 return true; 1759 } 1760 1761 /// Try to convert "shuffle (castop), (castop)" with a shared castop operand 1762 /// into "castop (shuffle)". 1763 bool VectorCombine::foldShuffleOfCastops(Instruction &I) { 1764 Value *V0, *V1; 1765 ArrayRef<int> OldMask; 1766 if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask)))) 1767 return false; 1768 1769 auto *C0 = dyn_cast<CastInst>(V0); 1770 auto *C1 = dyn_cast<CastInst>(V1); 1771 if (!C0 || !C1) 1772 return false; 1773 1774 Instruction::CastOps Opcode = C0->getOpcode(); 1775 if (C0->getSrcTy() != C1->getSrcTy()) 1776 return false; 1777 1778 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds. 1779 if (Opcode != C1->getOpcode()) { 1780 if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value()))) 1781 Opcode = Instruction::SExt; 1782 else 1783 return false; 1784 } 1785 1786 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1787 auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy()); 1788 auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy()); 1789 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy) 1790 return false; 1791 1792 unsigned NumSrcElts = CastSrcTy->getNumElements(); 1793 unsigned NumDstElts = CastDstTy->getNumElements(); 1794 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) && 1795 "Only bitcasts expected to alter src/dst element counts"); 1796 1797 // Check for bitcasting of unscalable vector types. 1798 // e.g. <32 x i40> -> <40 x i32> 1799 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 && 1800 (NumDstElts % NumSrcElts) != 0) 1801 return false; 1802 1803 SmallVector<int, 16> NewMask; 1804 if (NumSrcElts >= NumDstElts) { 1805 // The bitcast is from wide to narrow/equal elements. The shuffle mask can 1806 // always be expanded to the equivalent form choosing narrower elements. 1807 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask"); 1808 unsigned ScaleFactor = NumSrcElts / NumDstElts; 1809 narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask); 1810 } else { 1811 // The bitcast is from narrow elements to wide elements. The shuffle mask 1812 // must choose consecutive elements to allow casting first. 1813 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask"); 1814 unsigned ScaleFactor = NumDstElts / NumSrcElts; 1815 if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask)) 1816 return false; 1817 } 1818 1819 auto *NewShuffleDstTy = 1820 FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size()); 1821 1822 // Try to replace a castop with a shuffle if the shuffle is not costly. 
1823 InstructionCost CostC0 = 1824 TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy, 1825 TTI::CastContextHint::None, CostKind); 1826 InstructionCost CostC1 = 1827 TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy, 1828 TTI::CastContextHint::None, CostKind); 1829 InstructionCost OldCost = CostC0 + CostC1; 1830 OldCost += 1831 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy, 1832 OldMask, CostKind, 0, nullptr, {}, &I); 1833 1834 InstructionCost NewCost = TTI.getShuffleCost( 1835 TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind); 1836 NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy, 1837 TTI::CastContextHint::None, CostKind); 1838 if (!C0->hasOneUse()) 1839 NewCost += CostC0; 1840 if (!C1->hasOneUse()) 1841 NewCost += CostC1; 1842 1843 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I 1844 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1845 << "\n"); 1846 if (NewCost > OldCost) 1847 return false; 1848 1849 Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0), 1850 C1->getOperand(0), NewMask); 1851 Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy); 1852 1853 // Intersect flags from the old casts. 1854 if (auto *NewInst = dyn_cast<Instruction>(Cast)) { 1855 NewInst->copyIRFlags(C0); 1856 NewInst->andIRFlags(C1); 1857 } 1858 1859 Worklist.pushValue(Shuf); 1860 replaceValue(I, *Cast); 1861 return true; 1862 } 1863 1864 /// Try to convert any of: 1865 /// "shuffle (shuffle x, undef), (shuffle y, undef)" 1866 /// "shuffle (shuffle x, undef), y" 1867 /// "shuffle x, (shuffle y, undef)" 1868 /// into "shuffle x, y". 1869 bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { 1870 ArrayRef<int> OuterMask; 1871 Value *OuterV0, *OuterV1; 1872 if (!match(&I, 1873 m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask)))) 1874 return false; 1875 1876 ArrayRef<int> InnerMask0, InnerMask1; 1877 Value *V0 = nullptr, *V1 = nullptr; 1878 UndefValue *U0 = nullptr, *U1 = nullptr; 1879 bool Match0 = match( 1880 OuterV0, m_Shuffle(m_Value(V0), m_UndefValue(U0), m_Mask(InnerMask0))); 1881 bool Match1 = match( 1882 OuterV1, m_Shuffle(m_Value(V1), m_UndefValue(U1), m_Mask(InnerMask1))); 1883 if (!Match0 && !Match1) 1884 return false; 1885 1886 V0 = Match0 ? V0 : OuterV0; 1887 V1 = Match1 ? V1 : OuterV1; 1888 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1889 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType()); 1890 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType()); 1891 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy || 1892 V0->getType() != V1->getType()) 1893 return false; 1894 1895 unsigned NumSrcElts = ShuffleSrcTy->getNumElements(); 1896 unsigned NumImmElts = ShuffleImmTy->getNumElements(); 1897 1898 // Bail if either inner masks reference a RHS undef arg. 1899 if ((Match0 && !isa<PoisonValue>(U0) && 1900 any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) || 1901 (Match1 && !isa<PoisonValue>(U1) && 1902 any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; }))) 1903 return false; 1904 1905 // Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem, 1906 SmallVector<int, 16> NewMask(OuterMask); 1907 for (int &M : NewMask) { 1908 if (0 <= M && M < (int)NumImmElts) { 1909 if (Match0) 1910 M = (InnerMask0[M] >= (int)NumSrcElts) ? 
PoisonMaskElem : InnerMask0[M]; 1911 } else if (M >= (int)NumImmElts) { 1912 if (Match1) { 1913 if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts) 1914 M = PoisonMaskElem; 1915 else 1916 M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts); 1917 } 1918 } 1919 } 1920 1921 // Have we folded to an Identity shuffle? 1922 if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) { 1923 replaceValue(I, *V0); 1924 return true; 1925 } 1926 1927 // Try to merge the shuffles if the new shuffle is not costly. 1928 InstructionCost InnerCost0 = 0; 1929 if (Match0) 1930 InnerCost0 = TTI.getShuffleCost( 1931 TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, InnerMask0, 1932 CostKind, 0, nullptr, {V0, U0}, cast<ShuffleVectorInst>(OuterV0)); 1933 1934 InstructionCost InnerCost1 = 0; 1935 if (Match1) 1936 InnerCost1 = TTI.getShuffleCost( 1937 TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, InnerMask1, 1938 CostKind, 0, nullptr, {V1, U1}, cast<ShuffleVectorInst>(OuterV1)); 1939 1940 InstructionCost OuterCost = TTI.getShuffleCost( 1941 TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, OuterMask, CostKind, 1942 0, nullptr, {OuterV0, OuterV1}, &I); 1943 1944 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost; 1945 1946 InstructionCost NewCost = 1947 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, 1948 NewMask, CostKind, 0, nullptr, {V0, V1}); 1949 if (!OuterV0->hasOneUse()) 1950 NewCost += InnerCost0; 1951 if (!OuterV1->hasOneUse()) 1952 NewCost += InnerCost1; 1953 1954 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I 1955 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 1956 << "\n"); 1957 if (NewCost > OldCost) 1958 return false; 1959 1960 // Clear unused sources to poison. 1961 if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; })) 1962 V0 = PoisonValue::get(ShuffleSrcTy); 1963 if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; })) 1964 V1 = PoisonValue::get(ShuffleSrcTy); 1965 1966 Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask); 1967 replaceValue(I, *Shuf); 1968 return true; 1969 } 1970 1971 /// Try to convert 1972 /// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)". 
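/// As an illustrative sketch (%a0/%a1/%b0/%b1 are placeholder operands):
///   %v0 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %a0, <4 x i32> %b0)
///   %v1 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %a1, <4 x i32> %b1)
///   %r  = shufflevector <4 x i32> %v0, <4 x i32> %v1,
///                       <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
/// -->
///   %sa = shufflevector <4 x i32> %a0, <4 x i32> %a1, <8 x i32> <...same mask...>
///   %sb = shufflevector <4 x i32> %b0, <4 x i32> %b1, <8 x i32> <...same mask...>
///   %r  = call <8 x i32> @llvm.smax.v8i32(<8 x i32> %sa, <8 x i32> %sb)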
1973 bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) { 1974 Value *V0, *V1; 1975 ArrayRef<int> OldMask; 1976 if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)), 1977 m_Mask(OldMask)))) 1978 return false; 1979 1980 auto *II0 = dyn_cast<IntrinsicInst>(V0); 1981 auto *II1 = dyn_cast<IntrinsicInst>(V1); 1982 if (!II0 || !II1) 1983 return false; 1984 1985 Intrinsic::ID IID = II0->getIntrinsicID(); 1986 if (IID != II1->getIntrinsicID()) 1987 return false; 1988 1989 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); 1990 auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType()); 1991 if (!ShuffleDstTy || !II0Ty) 1992 return false; 1993 1994 if (!isTriviallyVectorizable(IID)) 1995 return false; 1996 1997 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 1998 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI) && 1999 II0->getArgOperand(I) != II1->getArgOperand(I)) 2000 return false; 2001 2002 InstructionCost OldCost = 2003 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) + 2004 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind) + 2005 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, II0Ty, OldMask, 2006 CostKind, 0, nullptr, {II0, II1}, &I); 2007 2008 SmallVector<Type *> NewArgsTy; 2009 InstructionCost NewCost = 0; 2010 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 2011 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) { 2012 NewArgsTy.push_back(II0->getArgOperand(I)->getType()); 2013 } else { 2014 auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType()); 2015 NewArgsTy.push_back(FixedVectorType::get(VecTy->getElementType(), 2016 VecTy->getNumElements() * 2)); 2017 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, 2018 VecTy, OldMask, CostKind); 2019 } 2020 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy); 2021 NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind); 2022 2023 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I 2024 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost 2025 << "\n"); 2026 2027 if (NewCost > OldCost) 2028 return false; 2029 2030 SmallVector<Value *> NewArgs; 2031 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) 2032 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) { 2033 NewArgs.push_back(II0->getArgOperand(I)); 2034 } else { 2035 Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I), 2036 II1->getArgOperand(I), OldMask); 2037 NewArgs.push_back(Shuf); 2038 Worklist.pushValue(Shuf); 2039 } 2040 Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs); 2041 2042 // Intersect flags from the old intrinsics. 
2043 if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) { 2044 NewInst->copyIRFlags(II0); 2045 NewInst->andIRFlags(II1); 2046 } 2047 2048 replaceValue(I, *NewIntrinsic); 2049 return true; 2050 } 2051 2052 using InstLane = std::pair<Use *, int>; 2053 2054 static InstLane lookThroughShuffles(Use *U, int Lane) { 2055 while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) { 2056 unsigned NumElts = 2057 cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements(); 2058 int M = SV->getMaskValue(Lane); 2059 if (M < 0) 2060 return {nullptr, PoisonMaskElem}; 2061 if (static_cast<unsigned>(M) < NumElts) { 2062 U = &SV->getOperandUse(0); 2063 Lane = M; 2064 } else { 2065 U = &SV->getOperandUse(1); 2066 Lane = M - NumElts; 2067 } 2068 } 2069 return InstLane{U, Lane}; 2070 } 2071 2072 static SmallVector<InstLane> 2073 generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) { 2074 SmallVector<InstLane> NItem; 2075 for (InstLane IL : Item) { 2076 auto [U, Lane] = IL; 2077 InstLane OpLane = 2078 U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op), 2079 Lane) 2080 : InstLane{nullptr, PoisonMaskElem}; 2081 NItem.emplace_back(OpLane); 2082 } 2083 return NItem; 2084 } 2085 2086 /// Detect concat of multiple values into a vector 2087 static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind, 2088 const TargetTransformInfo &TTI) { 2089 auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType()); 2090 unsigned NumElts = Ty->getNumElements(); 2091 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0) 2092 return false; 2093 2094 // Check that the concat is free, usually meaning that the type will be split 2095 // during legalization. 2096 SmallVector<int, 16> ConcatMask(NumElts * 2); 2097 std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 2098 if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, CostKind) != 0) 2099 return false; 2100 2101 unsigned NumSlices = Item.size() / NumElts; 2102 // Currently we generate a tree of shuffles for the concats, which limits us 2103 // to a power2. 
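  // e.g. (illustrative) four <4 x i32> slices s0..s3 would be rebuilt as:
  //   lo = shufflevector s0, s1, <8 x i32> <i32 0, ..., i32 7>
  //   hi = shufflevector s2, s3, <8 x i32> <i32 0, ..., i32 7>
  //   v  = shufflevector lo, hi, <16 x i32> <i32 0, ..., i32 15>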
2104 if (!isPowerOf2_32(NumSlices)) 2105 return false; 2106 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) { 2107 Use *SliceV = Item[Slice * NumElts].first; 2108 if (!SliceV || SliceV->get()->getType() != Ty) 2109 return false; 2110 for (unsigned Elt = 0; Elt < NumElts; ++Elt) { 2111 auto [V, Lane] = Item[Slice * NumElts + Elt]; 2112 if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get()) 2113 return false; 2114 } 2115 } 2116 return true; 2117 } 2118 2119 static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, 2120 const SmallPtrSet<Use *, 4> &IdentityLeafs, 2121 const SmallPtrSet<Use *, 4> &SplatLeafs, 2122 const SmallPtrSet<Use *, 4> &ConcatLeafs, 2123 IRBuilder<> &Builder, 2124 const TargetTransformInfo *TTI) { 2125 auto [FrontU, FrontLane] = Item.front(); 2126 2127 if (IdentityLeafs.contains(FrontU)) { 2128 return FrontU->get(); 2129 } 2130 if (SplatLeafs.contains(FrontU)) { 2131 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane); 2132 return Builder.CreateShuffleVector(FrontU->get(), Mask); 2133 } 2134 if (ConcatLeafs.contains(FrontU)) { 2135 unsigned NumElts = 2136 cast<FixedVectorType>(FrontU->get()->getType())->getNumElements(); 2137 SmallVector<Value *> Values(Item.size() / NumElts, nullptr); 2138 for (unsigned S = 0; S < Values.size(); ++S) 2139 Values[S] = Item[S * NumElts].first->get(); 2140 2141 while (Values.size() > 1) { 2142 NumElts *= 2; 2143 SmallVector<int, 16> Mask(NumElts, 0); 2144 std::iota(Mask.begin(), Mask.end(), 0); 2145 SmallVector<Value *> NewValues(Values.size() / 2, nullptr); 2146 for (unsigned S = 0; S < NewValues.size(); ++S) 2147 NewValues[S] = 2148 Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask); 2149 Values = NewValues; 2150 } 2151 return Values[0]; 2152 } 2153 2154 auto *I = cast<Instruction>(FrontU->get()); 2155 auto *II = dyn_cast<IntrinsicInst>(I); 2156 unsigned NumOps = I->getNumOperands() - (II ? 
1 : 0); 2157 SmallVector<Value *> Ops(NumOps); 2158 for (unsigned Idx = 0; Idx < NumOps; Idx++) { 2159 if (II && 2160 isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) { 2161 Ops[Idx] = II->getOperand(Idx); 2162 continue; 2163 } 2164 Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), 2165 Ty, IdentityLeafs, SplatLeafs, ConcatLeafs, 2166 Builder, TTI); 2167 } 2168 2169 SmallVector<Value *, 8> ValueList; 2170 for (const auto &Lane : Item) 2171 if (Lane.first) 2172 ValueList.push_back(Lane.first->get()); 2173 2174 Type *DstTy = 2175 FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements()); 2176 if (auto *BI = dyn_cast<BinaryOperator>(I)) { 2177 auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), 2178 Ops[0], Ops[1]); 2179 propagateIRFlags(Value, ValueList); 2180 return Value; 2181 } 2182 if (auto *CI = dyn_cast<CmpInst>(I)) { 2183 auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]); 2184 propagateIRFlags(Value, ValueList); 2185 return Value; 2186 } 2187 if (auto *SI = dyn_cast<SelectInst>(I)) { 2188 auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI); 2189 propagateIRFlags(Value, ValueList); 2190 return Value; 2191 } 2192 if (auto *CI = dyn_cast<CastInst>(I)) { 2193 auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(), 2194 Ops[0], DstTy); 2195 propagateIRFlags(Value, ValueList); 2196 return Value; 2197 } 2198 if (II) { 2199 auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops); 2200 propagateIRFlags(Value, ValueList); 2201 return Value; 2202 } 2203 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate"); 2204 auto *Value = 2205 Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]); 2206 propagateIRFlags(Value, ValueList); 2207 return Value; 2208 } 2209 2210 // Starting from a shuffle, look up through operands tracking the shuffled index 2211 // of each lane. If we can simplify away the shuffles to identities then 2212 // do so. 2213 bool VectorCombine::foldShuffleToIdentity(Instruction &I) { 2214 auto *Ty = dyn_cast<FixedVectorType>(I.getType()); 2215 if (!Ty || I.use_empty()) 2216 return false; 2217 2218 SmallVector<InstLane> Start(Ty->getNumElements()); 2219 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M) 2220 Start[M] = lookThroughShuffles(&*I.use_begin(), M); 2221 2222 SmallVector<SmallVector<InstLane>> Worklist; 2223 Worklist.push_back(Start); 2224 SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs; 2225 unsigned NumVisited = 0; 2226 2227 while (!Worklist.empty()) { 2228 if (++NumVisited > MaxInstrsToScan) 2229 return false; 2230 2231 SmallVector<InstLane> Item = Worklist.pop_back_val(); 2232 auto [FrontU, FrontLane] = Item.front(); 2233 2234 // If we found an undef first lane then bail out to keep things simple. 2235 if (!FrontU) 2236 return false; 2237 2238 // Helper to peek through bitcasts to the same value. 2239 auto IsEquiv = [&](Value *X, Value *Y) { 2240 return X->getType() == Y->getType() && 2241 peekThroughBitcasts(X) == peekThroughBitcasts(Y); 2242 }; 2243 2244 // Look for an identity value. 
2245 if (FrontLane == 0 && 2246 cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() == 2247 Ty->getNumElements() && 2248 all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) { 2249 Value *FrontV = Item.front().first->get(); 2250 return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) && 2251 E.value().second == (int)E.index()); 2252 })) { 2253 IdentityLeafs.insert(FrontU); 2254 continue; 2255 } 2256 // Look for constants, for the moment only supporting constant splats. 2257 if (auto *C = dyn_cast<Constant>(FrontU); 2258 C && C->getSplatValue() && 2259 all_of(drop_begin(Item), [Item](InstLane &IL) { 2260 Value *FrontV = Item.front().first->get(); 2261 Use *U = IL.first; 2262 return !U || U->get() == FrontV; 2263 })) { 2264 SplatLeafs.insert(FrontU); 2265 continue; 2266 } 2267 // Look for a splat value. 2268 if (all_of(drop_begin(Item), [Item](InstLane &IL) { 2269 auto [FrontU, FrontLane] = Item.front(); 2270 auto [U, Lane] = IL; 2271 return !U || (U->get() == FrontU->get() && Lane == FrontLane); 2272 })) { 2273 SplatLeafs.insert(FrontU); 2274 continue; 2275 } 2276 2277 // We need each element to be the same type of value, and check that each 2278 // element has a single use. 2279 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) { 2280 Value *FrontV = Item.front().first->get(); 2281 if (!IL.first) 2282 return true; 2283 Value *V = IL.first->get(); 2284 if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse()) 2285 return false; 2286 if (V->getValueID() != FrontV->getValueID()) 2287 return false; 2288 if (auto *CI = dyn_cast<CmpInst>(V)) 2289 if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate()) 2290 return false; 2291 if (auto *CI = dyn_cast<CastInst>(V)) 2292 if (CI->getSrcTy() != cast<CastInst>(FrontV)->getSrcTy()) 2293 return false; 2294 if (auto *SI = dyn_cast<SelectInst>(V)) 2295 if (!isa<VectorType>(SI->getOperand(0)->getType()) || 2296 SI->getOperand(0)->getType() != 2297 cast<SelectInst>(FrontV)->getOperand(0)->getType()) 2298 return false; 2299 if (isa<CallInst>(V) && !isa<IntrinsicInst>(V)) 2300 return false; 2301 auto *II = dyn_cast<IntrinsicInst>(V); 2302 return !II || (isa<IntrinsicInst>(FrontV) && 2303 II->getIntrinsicID() == 2304 cast<IntrinsicInst>(FrontV)->getIntrinsicID() && 2305 !II->hasOperandBundles()); 2306 }; 2307 if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) { 2308 // Check the operator is one that we support. 2309 if (isa<BinaryOperator, CmpInst>(FrontU)) { 2310 // We exclude div/rem in case they hit UB from poison lanes. 2311 if (auto *BO = dyn_cast<BinaryOperator>(FrontU); 2312 BO && BO->isIntDivRem()) 2313 return false; 2314 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2315 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); 2316 continue; 2317 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) { 2318 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2319 continue; 2320 } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) { 2321 // TODO: Handle vector widening/narrowing bitcasts. 
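        // (For example, a <4 x i32> -> <8 x i16> bitcast changes the lane
        // count, so it is rejected by the element-count check below.)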
2322 auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy()); 2323 auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy()); 2324 if (DstTy && SrcTy && 2325 SrcTy->getNumElements() == DstTy->getNumElements()) { 2326 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2327 continue; 2328 } 2329 } else if (isa<SelectInst>(FrontU)) { 2330 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0)); 2331 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1)); 2332 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2)); 2333 continue; 2334 } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU); 2335 II && isTriviallyVectorizable(II->getIntrinsicID()) && 2336 !II->hasOperandBundles()) { 2337 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) { 2338 if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op, 2339 &TTI)) { 2340 if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) { 2341 Value *FrontV = Item.front().first->get(); 2342 Use *U = IL.first; 2343 return !U || (cast<Instruction>(U->get())->getOperand(Op) == 2344 cast<Instruction>(FrontV)->getOperand(Op)); 2345 })) 2346 return false; 2347 continue; 2348 } 2349 Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op)); 2350 } 2351 continue; 2352 } 2353 } 2354 2355 if (isFreeConcat(Item, CostKind, TTI)) { 2356 ConcatLeafs.insert(FrontU); 2357 continue; 2358 } 2359 2360 return false; 2361 } 2362 2363 if (NumVisited <= 1) 2364 return false; 2365 2366 // If we got this far, we know the shuffles are superfluous and can be 2367 // removed. Scan through again and generate the new tree of instructions. 2368 Builder.SetInsertPoint(&I); 2369 Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, 2370 ConcatLeafs, Builder, &TTI); 2371 replaceValue(I, *V); 2372 return true; 2373 } 2374 2375 /// Given a commutative reduction, the order of the input lanes does not alter 2376 /// the results. We can use this to remove certain shuffles feeding the 2377 /// reduction, removing the need to shuffle at all. 2378 bool VectorCombine::foldShuffleFromReductions(Instruction &I) { 2379 auto *II = dyn_cast<IntrinsicInst>(&I); 2380 if (!II) 2381 return false; 2382 switch (II->getIntrinsicID()) { 2383 case Intrinsic::vector_reduce_add: 2384 case Intrinsic::vector_reduce_mul: 2385 case Intrinsic::vector_reduce_and: 2386 case Intrinsic::vector_reduce_or: 2387 case Intrinsic::vector_reduce_xor: 2388 case Intrinsic::vector_reduce_smin: 2389 case Intrinsic::vector_reduce_smax: 2390 case Intrinsic::vector_reduce_umin: 2391 case Intrinsic::vector_reduce_umax: 2392 break; 2393 default: 2394 return false; 2395 } 2396 2397 // Find all the inputs when looking through operations that do not alter the 2398 // lane order (binops, for example). Currently we look for a single shuffle, 2399 // and can ignore splat values. 2400 std::queue<Value *> Worklist; 2401 SmallPtrSet<Value *, 4> Visited; 2402 ShuffleVectorInst *Shuffle = nullptr; 2403 if (auto *Op = dyn_cast<Instruction>(I.getOperand(0))) 2404 Worklist.push(Op); 2405 2406 while (!Worklist.empty()) { 2407 Value *CV = Worklist.front(); 2408 Worklist.pop(); 2409 if (Visited.contains(CV)) 2410 continue; 2411 2412 // Splats don't change the order, so can be safely ignored. 
2413 if (isSplatValue(CV)) 2414 continue; 2415 2416 Visited.insert(CV); 2417 2418 if (auto *CI = dyn_cast<Instruction>(CV)) { 2419 if (CI->isBinaryOp()) { 2420 for (auto *Op : CI->operand_values()) 2421 Worklist.push(Op); 2422 continue; 2423 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) { 2424 if (Shuffle && Shuffle != SV) 2425 return false; 2426 Shuffle = SV; 2427 continue; 2428 } 2429 } 2430 2431 // Anything else is currently an unknown node. 2432 return false; 2433 } 2434 2435 if (!Shuffle) 2436 return false; 2437 2438 // Check all uses of the binary ops and shuffles are also included in the 2439 // lane-invariant operations (Visited should be the list of lanewise 2440 // instructions, including the shuffle that we found). 2441 for (auto *V : Visited) 2442 for (auto *U : V->users()) 2443 if (!Visited.contains(U) && U != &I) 2444 return false; 2445 2446 FixedVectorType *VecType = 2447 dyn_cast<FixedVectorType>(II->getOperand(0)->getType()); 2448 if (!VecType) 2449 return false; 2450 FixedVectorType *ShuffleInputType = 2451 dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType()); 2452 if (!ShuffleInputType) 2453 return false; 2454 unsigned NumInputElts = ShuffleInputType->getNumElements(); 2455 2456 // Find the mask from sorting the lanes into order. This is most likely to 2457 // become a identity or concat mask. Undef elements are pushed to the end. 2458 SmallVector<int> ConcatMask; 2459 Shuffle->getShuffleMask(ConcatMask); 2460 sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; }); 2461 // In the case of a truncating shuffle it's possible for the mask 2462 // to have an index greater than the size of the resulting vector. 2463 // This requires special handling. 2464 bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts; 2465 bool UsesSecondVec = 2466 any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; }); 2467 2468 FixedVectorType *VecTyForCost = 2469 (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType; 2470 InstructionCost OldCost = TTI.getShuffleCost( 2471 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, 2472 VecTyForCost, Shuffle->getShuffleMask(), CostKind); 2473 InstructionCost NewCost = TTI.getShuffleCost( 2474 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, 2475 VecTyForCost, ConcatMask, CostKind); 2476 2477 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle 2478 << "\n"); 2479 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost 2480 << "\n"); 2481 if (NewCost < OldCost) { 2482 Builder.SetInsertPoint(Shuffle); 2483 Value *NewShuffle = Builder.CreateShuffleVector( 2484 Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask); 2485 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n"); 2486 replaceValue(*Shuffle, *NewShuffle); 2487 } 2488 2489 // See if we can re-use foldSelectShuffle, getting it to reduce the size of 2490 // the shuffle into a nicer order, as it can ignore the order of the shuffles. 2491 return foldSelectShuffle(*Shuffle, true); 2492 } 2493 2494 /// Determine if its more efficient to fold: 2495 /// reduce(trunc(x)) -> trunc(reduce(x)). 2496 /// reduce(sext(x)) -> sext(reduce(x)). 2497 /// reduce(zext(x)) -> zext(reduce(x)). 
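/// As an illustrative sketch (%x is a placeholder operand):
///   %e = zext <8 x i16> %x to <8 x i32>
///   %r = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> %e)
/// -->
///   %n = call i16 @llvm.vector.reduce.and.v8i16(<8 x i16> %x)
///   %r = zext i16 %n to i32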
2498 bool VectorCombine::foldCastFromReductions(Instruction &I) { 2499 auto *II = dyn_cast<IntrinsicInst>(&I); 2500 if (!II) 2501 return false; 2502 2503 bool TruncOnly = false; 2504 Intrinsic::ID IID = II->getIntrinsicID(); 2505 switch (IID) { 2506 case Intrinsic::vector_reduce_add: 2507 case Intrinsic::vector_reduce_mul: 2508 TruncOnly = true; 2509 break; 2510 case Intrinsic::vector_reduce_and: 2511 case Intrinsic::vector_reduce_or: 2512 case Intrinsic::vector_reduce_xor: 2513 break; 2514 default: 2515 return false; 2516 } 2517 2518 unsigned ReductionOpc = getArithmeticReductionInstruction(IID); 2519 Value *ReductionSrc = I.getOperand(0); 2520 2521 Value *Src; 2522 if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) && 2523 (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src)))))) 2524 return false; 2525 2526 auto CastOpc = 2527 (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode(); 2528 2529 auto *SrcTy = cast<VectorType>(Src->getType()); 2530 auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType()); 2531 Type *ResultTy = I.getType(); 2532 2533 InstructionCost OldCost = TTI.getArithmeticReductionCost( 2534 ReductionOpc, ReductionSrcTy, std::nullopt, CostKind); 2535 OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy, 2536 TTI::CastContextHint::None, CostKind, 2537 cast<CastInst>(ReductionSrc)); 2538 InstructionCost NewCost = 2539 TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt, 2540 CostKind) + 2541 TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(), 2542 TTI::CastContextHint::None, CostKind); 2543 2544 if (OldCost <= NewCost || !NewCost.isValid()) 2545 return false; 2546 2547 Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(), 2548 II->getIntrinsicID(), {Src}); 2549 Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy); 2550 replaceValue(I, *NewCast); 2551 return true; 2552 } 2553 2554 /// This method looks for groups of shuffles acting on binops, of the form: 2555 /// %x = shuffle ... 2556 /// %y = shuffle ... 2557 /// %a = binop %x, %y 2558 /// %b = binop %x, %y 2559 /// shuffle %a, %b, selectmask 2560 /// We may, especially if the shuffle is wider than legal, be able to convert 2561 /// the shuffle to a form where only parts of a and b need to be computed. On 2562 /// architectures with no obvious "select" shuffle, this can reduce the total 2563 /// number of operations if the target reports them as cheaper. 
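/// As an illustrative sketch, for <8 x i32> operations with selectmask
/// <0, 1, 2, 3, 12, 13, 14, 15> only the low half of %a and the high half of
/// %b are live, so after packing the shuffled inputs only half of each binop
/// remains in use once the wide type is split by legalization.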
2564 bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { 2565 auto *SVI = cast<ShuffleVectorInst>(&I); 2566 auto *VT = cast<FixedVectorType>(I.getType()); 2567 auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0)); 2568 auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1)); 2569 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() || 2570 VT != Op0->getType()) 2571 return false; 2572 2573 auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0)); 2574 auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1)); 2575 auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0)); 2576 auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1)); 2577 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B}); 2578 auto checkSVNonOpUses = [&](Instruction *I) { 2579 if (!I || I->getOperand(0)->getType() != VT) 2580 return true; 2581 return any_of(I->users(), [&](User *U) { 2582 return U != Op0 && U != Op1 && 2583 !(isa<ShuffleVectorInst>(U) && 2584 (InputShuffles.contains(cast<Instruction>(U)) || 2585 isInstructionTriviallyDead(cast<Instruction>(U)))); 2586 }); 2587 }; 2588 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) || 2589 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B)) 2590 return false; 2591 2592 // Collect all the uses that are shuffles that we can transform together. We 2593 // may not have a single shuffle, but a group that can all be transformed 2594 // together profitably. 2595 SmallVector<ShuffleVectorInst *> Shuffles; 2596 auto collectShuffles = [&](Instruction *I) { 2597 for (auto *U : I->users()) { 2598 auto *SV = dyn_cast<ShuffleVectorInst>(U); 2599 if (!SV || SV->getType() != VT) 2600 return false; 2601 if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) || 2602 (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1)) 2603 return false; 2604 if (!llvm::is_contained(Shuffles, SV)) 2605 Shuffles.push_back(SV); 2606 } 2607 return true; 2608 }; 2609 if (!collectShuffles(Op0) || !collectShuffles(Op1)) 2610 return false; 2611 // From a reduction, we need to be processing a single shuffle, otherwise the 2612 // other uses will not be lane-invariant. 2613 if (FromReduction && Shuffles.size() > 1) 2614 return false; 2615 2616 // Add any shuffle uses for the shuffles we have found, to include them in our 2617 // cost calculations. 2618 if (!FromReduction) { 2619 for (ShuffleVectorInst *SV : Shuffles) { 2620 for (auto *U : SV->users()) { 2621 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U); 2622 if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT) 2623 Shuffles.push_back(SSV); 2624 } 2625 } 2626 } 2627 2628 // For each of the output shuffles, we try to sort all the first vector 2629 // elements to the beginning, followed by the second array elements at the 2630 // end. If the binops are legalized to smaller vectors, this may reduce total 2631 // number of binops. We compute the ReconstructMask mask needed to convert 2632 // back to the original lane order. 2633 SmallVector<std::pair<int, int>> V1, V2; 2634 SmallVector<SmallVector<int>> OrigReconstructMasks; 2635 int MaxV1Elt = 0, MaxV2Elt = 0; 2636 unsigned NumElts = VT->getNumElements(); 2637 for (ShuffleVectorInst *SVN : Shuffles) { 2638 SmallVector<int> Mask; 2639 SVN->getShuffleMask(Mask); 2640 2641 // Check the operands are the same as the original, or reversed (in which 2642 // case we need to commute the mask). 
2643     Value *SVOp0 = SVN->getOperand(0);
2644     Value *SVOp1 = SVN->getOperand(1);
2645     if (isa<UndefValue>(SVOp1)) {
2646       auto *SSV = cast<ShuffleVectorInst>(SVOp0);
2647       SVOp0 = SSV->getOperand(0);
2648       SVOp1 = SSV->getOperand(1);
2649       for (unsigned I = 0, E = Mask.size(); I != E; I++) {
2650         if (Mask[I] >= static_cast<int>(SSV->getShuffleMask().size()))
2651           return false;
2652         Mask[I] = Mask[I] < 0 ? Mask[I] : SSV->getMaskValue(Mask[I]);
2653       }
2654     }
2655     if (SVOp0 == Op1 && SVOp1 == Op0) {
2656       std::swap(SVOp0, SVOp1);
2657       ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
2658     }
2659     if (SVOp0 != Op0 || SVOp1 != Op1)
2660       return false;
2661 
2662     // Calculate the reconstruction mask for this shuffle, as the mask needed
2663     // to take the packed values from Op0/Op1 and reconstruct the original
2664     // lane order.
2665     SmallVector<int> ReconstructMask;
2666     for (unsigned I = 0; I < Mask.size(); I++) {
2667       if (Mask[I] < 0) {
2668         ReconstructMask.push_back(-1);
2669       } else if (Mask[I] < static_cast<int>(NumElts)) {
2670         MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
2671         auto It = find_if(V1, [&](const std::pair<int, int> &A) {
2672           return Mask[I] == A.first;
2673         });
2674         if (It != V1.end())
2675           ReconstructMask.push_back(It - V1.begin());
2676         else {
2677           ReconstructMask.push_back(V1.size());
2678           V1.emplace_back(Mask[I], V1.size());
2679         }
2680       } else {
2681         MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
2682         auto It = find_if(V2, [&](const std::pair<int, int> &A) {
2683           return Mask[I] - static_cast<int>(NumElts) == A.first;
2684         });
2685         if (It != V2.end())
2686           ReconstructMask.push_back(NumElts + It - V2.begin());
2687         else {
2688           ReconstructMask.push_back(NumElts + V2.size());
2689           V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
2690         }
2691       }
2692     }
2693 
2694     // For reductions, we know that the output lane ordering doesn't alter the
2695     // result. Sorting in-order can help simplify the shuffle away.
2696     if (FromReduction)
2697       sort(ReconstructMask);
2698     OrigReconstructMasks.push_back(std::move(ReconstructMask));
2699   }
2700 
2701   // If the maximum elements used from V1 and V2 are not larger than the new
2702   // vectors, the vectors are already packed and performing the optimization
2703   // again will likely not help any further. This also prevents us from getting
2704   // stuck in a cycle in case the costs do not also rule it out.
2705   if (V1.empty() || V2.empty() ||
2706       (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
2707        MaxV2Elt == static_cast<int>(V2.size()) - 1))
2708     return false;
2709 
2710   // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
2711   // shuffle of another shuffle, or not a shuffle (that is treated like an
2712   // identity shuffle).
2713   auto GetBaseMaskValue = [&](Instruction *I, int M) {
2714     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2715     if (!SV)
2716       return M;
2717     if (isa<UndefValue>(SV->getOperand(1)))
2718       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
2719         if (InputShuffles.contains(SSV))
2720           return SSV->getMaskValue(SV->getMaskValue(M));
2721     return SV->getMaskValue(M);
2722   };
2723 
2724   // Attempt to sort the inputs by ascending mask values to make simpler input
2725   // shuffles and push complex shuffles down to the uses. We sort on the first
2726   // of the two input shuffle orders, to try and get at least one input into a
2727   // nice order.
2728 auto SortBase = [&](Instruction *A, std::pair<int, int> X, 2729 std::pair<int, int> Y) { 2730 int MXA = GetBaseMaskValue(A, X.first); 2731 int MYA = GetBaseMaskValue(A, Y.first); 2732 return MXA < MYA; 2733 }; 2734 stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) { 2735 return SortBase(SVI0A, A, B); 2736 }); 2737 stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) { 2738 return SortBase(SVI1A, A, B); 2739 }); 2740 // Calculate our ReconstructMasks from the OrigReconstructMasks and the 2741 // modified order of the input shuffles. 2742 SmallVector<SmallVector<int>> ReconstructMasks; 2743 for (const auto &Mask : OrigReconstructMasks) { 2744 SmallVector<int> ReconstructMask; 2745 for (int M : Mask) { 2746 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) { 2747 auto It = find_if(V, [M](auto A) { return A.second == M; }); 2748 assert(It != V.end() && "Expected all entries in Mask"); 2749 return std::distance(V.begin(), It); 2750 }; 2751 if (M < 0) 2752 ReconstructMask.push_back(-1); 2753 else if (M < static_cast<int>(NumElts)) { 2754 ReconstructMask.push_back(FindIndex(V1, M)); 2755 } else { 2756 ReconstructMask.push_back(NumElts + FindIndex(V2, M)); 2757 } 2758 } 2759 ReconstructMasks.push_back(std::move(ReconstructMask)); 2760 } 2761 2762 // Calculate the masks needed for the new input shuffles, which get padded 2763 // with undef 2764 SmallVector<int> V1A, V1B, V2A, V2B; 2765 for (unsigned I = 0; I < V1.size(); I++) { 2766 V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first)); 2767 V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first)); 2768 } 2769 for (unsigned I = 0; I < V2.size(); I++) { 2770 V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first)); 2771 V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first)); 2772 } 2773 while (V1A.size() < NumElts) { 2774 V1A.push_back(PoisonMaskElem); 2775 V1B.push_back(PoisonMaskElem); 2776 } 2777 while (V2A.size() < NumElts) { 2778 V2A.push_back(PoisonMaskElem); 2779 V2B.push_back(PoisonMaskElem); 2780 } 2781 2782 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) { 2783 auto *SV = dyn_cast<ShuffleVectorInst>(I); 2784 if (!SV) 2785 return C; 2786 return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1)) 2787 ? TTI::SK_PermuteSingleSrc 2788 : TTI::SK_PermuteTwoSrc, 2789 VT, SV->getShuffleMask(), CostKind); 2790 }; 2791 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) { 2792 return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask, CostKind); 2793 }; 2794 2795 // Get the costs of the shuffles + binops before and after with the new 2796 // shuffle masks. 2797 InstructionCost CostBefore = 2798 TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) + 2799 TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind); 2800 CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(), 2801 InstructionCost(0), AddShuffleCost); 2802 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(), 2803 InstructionCost(0), AddShuffleCost); 2804 2805 // The new binops will be unused for lanes past the used shuffle lengths. 2806 // These types attempt to get the correct cost for that from the target. 
2807   FixedVectorType *Op0SmallVT =
2808       FixedVectorType::get(VT->getScalarType(), V1.size());
2809   FixedVectorType *Op1SmallVT =
2810       FixedVectorType::get(VT->getScalarType(), V2.size());
2811   InstructionCost CostAfter =
2812       TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
2813       TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
2814   CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
2815                                InstructionCost(0), AddShuffleMaskCost);
2816   std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
2817   CostAfter +=
2818       std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
2819                       InstructionCost(0), AddShuffleMaskCost);
2820 
2821   LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
2822   LLVM_DEBUG(dbgs() << "  CostBefore: " << CostBefore
2823                     << " vs CostAfter: " << CostAfter << "\n");
2824   if (CostBefore <= CostAfter)
2825     return false;
2826 
2827   // The cost model has passed, create the new instructions.
2828   auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
2829     auto *SV = dyn_cast<ShuffleVectorInst>(I);
2830     if (!SV)
2831       return I;
2832     if (isa<UndefValue>(SV->getOperand(1)))
2833       if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
2834         if (InputShuffles.contains(SSV))
2835           return SSV->getOperand(Op);
2836     return SV->getOperand(Op);
2837   };
2838   Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
2839   Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
2840                                              GetShuffleOperand(SVI0A, 1), V1A);
2841   Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
2842   Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
2843                                              GetShuffleOperand(SVI0B, 1), V1B);
2844   Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
2845   Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
2846                                              GetShuffleOperand(SVI1A, 1), V2A);
2847   Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
2848   Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
2849                                              GetShuffleOperand(SVI1B, 1), V2B);
2850   Builder.SetInsertPoint(Op0);
2851   Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
2852                                     NSV0A, NSV0B);
2853   if (auto *I = dyn_cast<Instruction>(NOp0))
2854     I->copyIRFlags(Op0, true);
2855   Builder.SetInsertPoint(Op1);
2856   Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
2857                                     NSV1A, NSV1B);
2858   if (auto *I = dyn_cast<Instruction>(NOp1))
2859     I->copyIRFlags(Op1, true);
2860 
2861   for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
2862     Builder.SetInsertPoint(Shuffles[S]);
2863     Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
2864     replaceValue(*Shuffles[S], *NSV);
2865   }
2866 
2867   Worklist.pushValue(NSV0A);
2868   Worklist.pushValue(NSV0B);
2869   Worklist.pushValue(NSV1A);
2870   Worklist.pushValue(NSV1B);
2871   for (auto *S : Shuffles)
2872     Worklist.add(S);
2873   return true;
2874 }
2875 
2876 /// Check if the instruction depends on a ZExt that can be moved after the
2877 /// instruction. Move the ZExt if it is profitable. For example:
2878 /// logic(zext(x),y) -> zext(logic(x,trunc(y)))
2879 /// lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
2880 /// Cost model calculations take into account whether zext(x) has other users
2881 /// and whether it can be propagated through them too.
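/// As an illustrative sketch (%x and %y are placeholder operands):
///   %zx = zext <8 x i8> %x to <8 x i32>
///   %r  = and <8 x i32> %zx, %y
/// -->
///   %t  = trunc <8 x i32> %y to <8 x i8>
///   %a  = and <8 x i8> %x, %t
///   %r  = zext <8 x i8> %a to <8 x i32>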
2882 bool VectorCombine::shrinkType(Instruction &I) {
2883   Value *ZExted, *OtherOperand;
2884   if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
2885                                   m_Value(OtherOperand))) &&
2886       !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
2887     return false;
2888
2889   Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
2890
2891   auto *BigTy = cast<FixedVectorType>(I.getType());
2892   auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
2893   unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
2894
2895   if (I.getOpcode() == Instruction::LShr) {
2896     // Check that the shift amount is less than the number of bits in the
2897     // smaller type. Otherwise, the smaller lshr will return a poison value.
2898     KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
2899     if (ShAmtKB.getMaxValue().uge(BW))
2900       return false;
2901   } else {
2902     // Check that the expression overall uses at most the same number of bits
2903     // as ZExted.
2904     KnownBits KB = computeKnownBits(&I, *DL);
2905     if (KB.countMaxActiveBits() > BW)
2906       return false;
2907   }
2908
2909   // Calculate the cost of leaving the current IR as it is and the cost of
2910   // moving the ZExt later, along with any truncates that would be needed.
2911   InstructionCost ZExtCost = TTI.getCastInstrCost(
2912       Instruction::ZExt, BigTy, SmallTy,
2913       TargetTransformInfo::CastContextHint::None, CostKind);
2914   InstructionCost CurrentCost = ZExtCost;
2915   InstructionCost ShrinkCost = 0;
2916
2917   // Calculate the total cost and check we can propagate through all ZExt users.
2918   for (User *U : ZExtOperand->users()) {
2919     auto *UI = cast<Instruction>(U);
2920     if (UI == &I) {
2921       CurrentCost +=
2922           TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2923       ShrinkCost +=
2924           TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2925       ShrinkCost += ZExtCost;
2926       continue;
2927     }
2928
2929     if (!Instruction::isBinaryOp(UI->getOpcode()))
2930       return false;
2931
2932     // Check if we can propagate the ZExt through its other users.
2933     KnownBits KB = computeKnownBits(UI, *DL);
2934     if (KB.countMaxActiveBits() > BW)
2935       return false;
2936
2937     CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
2938     ShrinkCost +=
2939         TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
2940     ShrinkCost += ZExtCost;
2941   }
2942
2943   // If the other instruction operand is not a constant, we'll need to generate
2944   // a truncate instruction, so we have to adjust the cost.
2945   if (!isa<Constant>(OtherOperand))
2946     ShrinkCost += TTI.getCastInstrCost(
2947         Instruction::Trunc, SmallTy, BigTy,
2948         TargetTransformInfo::CastContextHint::None, CostKind);
2949
2950   // If the cost of shrinking types equals the cost of leaving the IR as it is,
2951   // lean towards modifying the IR, because shrinking opens up opportunities for
2952   // other shrinking optimisations.
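  // Hence the strict '>' below: when the costs tie, we still shrink.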
2953   if (ShrinkCost > CurrentCost)
2954     return false;
2955
2956   Builder.SetInsertPoint(&I);
2957   Value *Op0 = ZExted;
2958   Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
2959   // Keep the order of operands the same.
2960   if (I.getOperand(0) == OtherOperand)
2961     std::swap(Op0, Op1);
2962   Value *NewBinOp =
2963       Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
2964   cast<Instruction>(NewBinOp)->copyIRFlags(&I);
2965   cast<Instruction>(NewBinOp)->copyMetadata(I);
2966   Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
2967   replaceValue(I, *NewZExtr);
2968   return true;
2969 }
2970
2971 /// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
2972 /// shuffle (DstVec, SrcVec, Mask)
2973 bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
2974   Value *DstVec, *SrcVec;
2975   uint64_t ExtIdx, InsIdx;
2976   if (!match(&I,
2977              m_InsertElt(m_Value(DstVec),
2978                          m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
2979                          m_ConstantInt(InsIdx))))
2980     return false;
2981
2982   auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
2983   if (!VecTy || SrcVec->getType() != VecTy)
2984     return false;
2985
2986   unsigned NumElts = VecTy->getNumElements();
2987   if (ExtIdx >= NumElts || InsIdx >= NumElts)
2988     return false;
2989
2990   SmallVector<int> Mask(NumElts, 0);
2991   std::iota(Mask.begin(), Mask.end(), 0);
2992   Mask[InsIdx] = ExtIdx + NumElts;
2993   // Cost the original extract + insert against the replacement shuffle.
2994   auto *Ins = cast<InsertElementInst>(&I);
2995   auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
2996
2997   InstructionCost OldCost =
2998       TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx) +
2999       TTI.getVectorInstrCost(*Ins, VecTy, CostKind, InsIdx);
3000
3001   InstructionCost NewCost = TTI.getShuffleCost(
3002       TargetTransformInfo::SK_PermuteTwoSrc, VecTy, Mask, CostKind);
3003   if (!Ext->hasOneUse())
3004     NewCost += TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx);
3005
3006   LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair: " << I
3007                     << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3008                     << "\n");
3009
3010   if (OldCost < NewCost)
3011     return false;
3012
3013   // Canonicalize an undef operand to the RHS to help further folds.
3014   if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
3015     ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
3016     std::swap(DstVec, SrcVec);
3017   }
3018
3019   Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
3020   replaceValue(I, *Shuf);
3021
3022   return true;
3023 }
3024
3025 /// This is the entry point for all transforms. Pass manager differences are
3026 /// handled in the callers of this function.
3027 bool VectorCombine::run() {
3028   if (DisableVectorCombine)
3029     return false;
3030
3031   // Don't attempt vectorization if the target does not support vectors.
3032   if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
3033     return false;
3034
3035   LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
3036
3037   bool MadeChange = false;
3038   auto FoldInst = [this, &MadeChange](Instruction &I) {
3039     Builder.SetInsertPoint(&I);
3040     bool IsVectorType = isa<VectorType>(I.getType());
3041     bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
3042     auto Opcode = I.getOpcode();
3043
3044     // These folds should be beneficial regardless of when this pass is run
3045     // in the optimization pipeline.
3046     // The type checking is for run-time efficiency. We can avoid wasting time
3047     // dispatching to folding functions if there's no chance of matching.
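    // For example, vectorizeLoadInsert below only applies to an insertelement
    // into a fixed-width vector, so it is not worth dispatching to it for
    // scalable vectors or other opcodes.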
3048 if (IsFixedVectorType) { 3049 switch (Opcode) { 3050 case Instruction::InsertElement: 3051 MadeChange |= vectorizeLoadInsert(I); 3052 break; 3053 case Instruction::ShuffleVector: 3054 MadeChange |= widenSubvectorLoad(I); 3055 break; 3056 default: 3057 break; 3058 } 3059 } 3060 3061 // This transform works with scalable and fixed vectors 3062 // TODO: Identify and allow other scalable transforms 3063 if (IsVectorType) { 3064 MadeChange |= scalarizeBinopOrCmp(I); 3065 MadeChange |= scalarizeLoadExtract(I); 3066 MadeChange |= scalarizeVPIntrinsic(I); 3067 } 3068 3069 if (Opcode == Instruction::Store) 3070 MadeChange |= foldSingleElementStore(I); 3071 3072 // If this is an early pipeline invocation of this pass, we are done. 3073 if (TryEarlyFoldsOnly) 3074 return; 3075 3076 // Otherwise, try folds that improve codegen but may interfere with 3077 // early IR canonicalizations. 3078 // The type checking is for run-time efficiency. We can avoid wasting time 3079 // dispatching to folding functions if there's no chance of matching. 3080 if (IsFixedVectorType) { 3081 switch (Opcode) { 3082 case Instruction::InsertElement: 3083 MadeChange |= foldInsExtFNeg(I); 3084 MadeChange |= foldInsExtVectorToShuffle(I); 3085 break; 3086 case Instruction::ShuffleVector: 3087 MadeChange |= foldPermuteOfBinops(I); 3088 MadeChange |= foldShuffleOfBinops(I); 3089 MadeChange |= foldShuffleOfCastops(I); 3090 MadeChange |= foldShuffleOfShuffles(I); 3091 MadeChange |= foldShuffleOfIntrinsics(I); 3092 MadeChange |= foldSelectShuffle(I); 3093 MadeChange |= foldShuffleToIdentity(I); 3094 break; 3095 case Instruction::BitCast: 3096 MadeChange |= foldBitcastShuffle(I); 3097 break; 3098 default: 3099 MadeChange |= shrinkType(I); 3100 break; 3101 } 3102 } else { 3103 switch (Opcode) { 3104 case Instruction::Call: 3105 MadeChange |= foldShuffleFromReductions(I); 3106 MadeChange |= foldCastFromReductions(I); 3107 break; 3108 case Instruction::ICmp: 3109 case Instruction::FCmp: 3110 MadeChange |= foldExtractExtract(I); 3111 break; 3112 case Instruction::Or: 3113 MadeChange |= foldConcatOfBoolMasks(I); 3114 [[fallthrough]]; 3115 default: 3116 if (Instruction::isBinaryOp(Opcode)) { 3117 MadeChange |= foldExtractExtract(I); 3118 MadeChange |= foldExtractedCmps(I); 3119 } 3120 break; 3121 } 3122 } 3123 }; 3124 3125 for (BasicBlock &BB : F) { 3126 // Ignore unreachable basic blocks. 3127 if (!DT.isReachableFromEntry(&BB)) 3128 continue; 3129 // Use early increment range so that we can erase instructions in loop. 
3130 for (Instruction &I : make_early_inc_range(BB)) { 3131 if (I.isDebugOrPseudoInst()) 3132 continue; 3133 FoldInst(I); 3134 } 3135 } 3136 3137 while (!Worklist.isEmpty()) { 3138 Instruction *I = Worklist.removeOne(); 3139 if (!I) 3140 continue; 3141 3142 if (isInstructionTriviallyDead(I)) { 3143 eraseInstruction(*I); 3144 continue; 3145 } 3146 3147 FoldInst(*I); 3148 } 3149 3150 return MadeChange; 3151 } 3152 3153 PreservedAnalyses VectorCombinePass::run(Function &F, 3154 FunctionAnalysisManager &FAM) { 3155 auto &AC = FAM.getResult<AssumptionAnalysis>(F); 3156 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); 3157 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); 3158 AAResults &AA = FAM.getResult<AAManager>(F); 3159 const DataLayout *DL = &F.getDataLayout(); 3160 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput, 3161 TryEarlyFoldsOnly); 3162 if (!Combiner.run()) 3163 return PreservedAnalyses::all(); 3164 PreservedAnalyses PA; 3165 PA.preserveSet<CFGAnalyses>(); 3166 return PA; 3167 } 3168
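// A quick way to exercise this pass in isolation, assuming a standard LLVM
// build with the new pass manager, is along the lines of:
//   opt -passes=vector-combine -S input.ll
// (TryEarlyFoldsOnly is selected by the callers that build the pass pipeline
// rather than by a separate flag here.)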