//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <optional>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}
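// Return true if the mask is a fixed-width vector of integer constants, i.e.
// every lane's predicate is known at compile time and the expansions below
// can be emitted without any control flow.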
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
//   <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if the
// appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                       ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                            ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                      ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                           ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
                                CallInst *CI, DomTreeUpdater *DTU,
                                bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    NewI->copyMetadata(*CI);
    NewI->takeName(CI);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // Optimize the case where the "masked load" is a predicated load - that is,
  // where the mask is the splat of a non-constant scalar boolean. In that
  // case, use that splatted value as the guard on a conditional vector load.
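  //
  // For example (illustrative value names; the real names are derived from
  // the operands), the splat case is lowered roughly as:
  //
  //  %mask.first = extractelement <16 x i1> %mask, i64 0
  //  br i1 %mask.first, label %cond.load, label %else
  //
  // cond.load:
  //  %res.cond.load = load <16 x i32>, ptr %addr, align N
  //  br label %else
  //
  // else:
  //  %res = phi <16 x i32> [ %res.cond.load, %cond.load ], [ %passthru, %0 ]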
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
                                                    Mask->getName() + ".first");
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");
    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
                                               CI->getName() + ".cond.load");
    Load->copyMetadata(*CI);

    BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
    Builder.SetInsertPoint(PostLoad, PostLoad->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
    Phi->addIncoming(Load, CondBlock);
    Phi->addIncoming(Src0, IfBlock);
    Phi->takeName(CI);

    CI->replaceAllUsesWith(Phi);
    CI->eraseFromParent();
    ModifiedDT = true;
    return;
  }
  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs and other
  // machines with branch divergence, as there each i1 needs a vector register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
//   void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                           <16 x i1> %mask)
// to a chain of basic blocks that store the elements one-by-one if the
// appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                      ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                            ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                     ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//  . . .
static void scalarizeMaskedStore(const DataLayout &DL,
                                 bool HasBranchDivergence, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->takeName(CI);
    Store->copyMetadata(*CI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // Optimize the case where the "masked store" is a predicated store - that
  // is, when the mask is the splat of a non-constant scalar boolean. In that
  // case, optimize to a conditional store.
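  //
  // For example (illustrative value names), the splat case is lowered roughly
  // as:
  //
  //  %mask.first = extractelement <16 x i1> %mask, i64 0
  //  br i1 %mask.first, label %cond.store, label %else
  //
  // cond.store:
  //  store <16 x i32> %src, ptr %addr, align N
  //  br label %else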
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
                                                    Mask->getName() + ".first");
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);
    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");
    Builder.SetInsertPoint(CondBlock->getTerminator());

    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->takeName(CI);
    Store->copyMetadata(*CI);

    CI->eraseFromParent();
    ModifiedDT = true;
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
//   <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if the
// appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ poison, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//  . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL,
                                  bool HasBranchDivergence, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
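  //
  // Note: the bit tested below is endian-adjusted (see adjustForEndian), so
  // that lane Idx of the mask lines up with the corresponding bit of the
  // bitcast scalar mask on both little- and big-endian targets.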
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead

    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
//   void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                    <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one-by-one if the
// appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//  . . .
static void scalarizeMaskedScatter(const DataLayout &DL,
                                   bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

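// Translate a masked expandload intrinsic, like
//   <16 x i32> @llvm.masked.expandload.v16i32(ptr %ptr, <16 x i1> %mask,
//                                             <16 x i32> %passthru)
// to a chain of basic blocks that load from consecutive memory locations only
// for the lanes whose mask bit is set, taking the pass-through value for the
// other lanes and advancing the pointer only past elements that were actually
// loaded.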
static void scalarizeMaskedExpandLoad(const DataLayout &DL,
                                      bool HasBranchDivergence, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(0).valueOrOne();

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  // Short-cut if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/poisons as necessary and then
  // shuffle blend with the pass-through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = PoisonValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead

    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

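// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore.v16i32(<16 x i32> %value, ptr %ptr,
//                                          <16 x i1> %mask)
// to a chain of basic blocks that store the selected elements to consecutive
// memory locations, advancing the pointer only for the lanes whose mask bit
// is set.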
static void scalarizeMaskedCompressStore(const DataLayout &DL,
                                         bool HasBranchDivergence, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(1).valueOrOne();

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

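// Scalarize a vector histogram intrinsic, like
//   void @llvm.experimental.vector.histogram.add(<8 x ptr> %ptrs, i32 %inc,
//                                                <8 x i1> %mask)
// by loading, incrementing and storing back each addressed bucket
// individually, but only for the lanes whose mask bit is set.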
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
                                           DomTreeUpdater *DTU,
                                           bool &ModifiedDT) {
  // If we extend histogram to return a result someday (like the updated
  // vector) then we'll need to support it here.
  assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
  Value *Ptrs = CI->getArgOperand(0);
  Value *Inc = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
  Type *EltTy = Inc->getType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // FIXME: Do we need to add an alignment parameter to the intrinsic?
  unsigned VectorWidth = AddrType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
      Value *Add = Builder.CreateAdd(Load, Inc);
      Builder.CreateStore(Add, Ptr);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.histogram.update");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
    Value *Add = Builder.CreateAdd(Load, Inc);
    Builder.CreateStore(Add, Ptr);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }

  CI->eraseFromParent();
  ModifiedDT = true;
}

static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getDataLayout();
  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  HasBranchDivergence, DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |=
          optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
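    // Their element count is only known at run time, so a fixed chain of
    // conditional blocks cannot be emitted for them here.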
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::experimental_vector_histogram_add:
      if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
                                           CI->getArgOperand(1)->getType()))
        return false;
      scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(
              CI->getType(),
              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(
              CI->getArgOperand(0)->getType(),
              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
                                   ModifiedDT);
      return true;
    }
  }

  return false;
}