//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

16 #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Analysis/DomTreeUpdater.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/VectorUtils.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/IR/Instruction.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/IR/Value.h"
33 #include "llvm/InitializePasses.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Transforms/Scalar.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38 #include <cassert>
39 #include <optional>
40 
41 using namespace llvm;
42 
43 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
44 
namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

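// Return true if \p Mask is a fixed-width vector constant whose elements are
// all ConstantInt, i.e. the predicate of every lane is statically known.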
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

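// When an <N x i1> mask is bitcast to an iN scalar, lane Idx ends up at bit
// position Idx on little-endian targets but at position N - 1 - Idx on
// big-endian targets. Return the bit index to test for the given lane.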
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
                                CallInst *CI, DomTreeUpdater *DTU,
                                bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    NewI->copyMetadata(*CI);
    NewI->takeName(CI);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // Optimize the case where the "masked load" is a predicated load - that is,
  // where the mask is the splat of a non-constant scalar boolean. In that case,
  // use that splatted value as the guard on a conditional vector load.
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
                                                    Mask->getName() + ".first");
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");
    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
                                               CI->getName() + ".cond.load");
    Load->copyMetadata(*CI);

    BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
    Builder.SetInsertPoint(PostLoad, PostLoad->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
    Phi->addIncoming(Load, CondBlock);
    Phi->addIncoming(Src0, IfBlock);
    Phi->takeName(CI);

    CI->replaceAllUsesWith(Phi);
    CI->eraseFromParent();
    ModifiedDT = true;
    return;
  }
  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs and other
  // machines with divergence, as there each i1 needs a vector register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
                                 CallInst *CI, DomTreeUpdater *DTU,
                                 bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->takeName(CI);
    Store->copyMetadata(*CI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // Optimize the case where the "masked store" is a predicated store - that is,
  // when the mask is the splat of a non-constant scalar boolean. In that case,
  // optimize to a conditional store.
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
                                                    Mask->getName() + ".first");
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);
    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");
    Builder.SetInsertPoint(CondBlock->getTerminator());

    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->takeName(CI);
    Store->copyMetadata(*CI);

    CI->eraseFromParent();
    ModifiedDT = true;
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                               <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL,
                                  bool HasBranchDivergence, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead

    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(const DataLayout &DL,
                                   bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

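// Translate a masked expandload intrinsic, like
// <16 x i32> @llvm.masked.expandload(ptr %ptr, <16 x i1> %mask,
//                                    <16 x i32> %passthru)
// to a chain of basic blocks that load the active elements from consecutive
// memory locations: the pointer advances only when a lane's mask bit is set,
// and lanes with a clear mask bit keep the pass-through value.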
static void scalarizeMaskedExpandLoad(const DataLayout &DL,
                                      bool HasBranchDivergence, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(0).valueOrOne();

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/poisons as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = PoisonValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead

    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

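// Translate a masked compressstore intrinsic, like
// void @llvm.masked.compressstore(<16 x i32> %src, ptr %ptr, <16 x i1> %mask)
// to a chain of basic blocks that store the active elements to consecutive
// memory locations, advancing the pointer only when a lane's mask bit is set.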
static void scalarizeMaskedCompressStore(const DataLayout &DL,
                                         bool HasBranchDivergence, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Align Alignment = CI->getParamAlign(1).valueOrOne();

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignment =
      commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least. However, don't do this on GPUs or other
  // machines with branch divergence, as there, each i1 takes up a register.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    // On GPUs, use
    //  %cond = extractelement %mask, Idx
    // instead
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

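// Scalarize a llvm.experimental.vector.histogram.add intrinsic: for each lane
// whose mask bit is set, load the bucket addressed by that lane of %Ptrs, add
// %Inc to it, and store the result back.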
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
                                           DomTreeUpdater *DTU,
                                           bool &ModifiedDT) {
  // If we extend histogram to return a result someday (like the updated vector)
  // then we'll need to support it here.
  assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
  Value *Ptrs = CI->getArgOperand(0);
  Value *Inc = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
  Type *EltTy = Inc->getType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // FIXME: Do we need to add an alignment parameter to the intrinsic?
  unsigned VectorWidth = AddrType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
      Value *Add = Builder.CreateAdd(Load, Inc);
      Builder.CreateStore(Add, Ptr);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.histogram.update");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
    Value *Add = Builder.CreateAdd(Load, Inc);
    Builder.CreateStore(Add, Ptr);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }

  CI->eraseFromParent();
  ModifiedDT = true;
}

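// Scan the function and scalarize every masked memory intrinsic the target
// does not support natively. Restart the block walk whenever a transformation
// changes the CFG, and iterate until no further changes are made.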
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getDataLayout();
  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  HasBranchDivergence, DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |=
          optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

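// Scalarize a single call if it is a masked memory intrinsic that the target
// reports as not legal (or forces to be scalarized). Returns true if the call
// was replaced.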
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::experimental_vector_histogram_add:
      if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
                                           CI->getArgOperand(1)->getType()))
        return false;
      scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(
              CI->getType(),
              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(
              CI->getArgOperand(0)->getType(),
              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
                                   ModifiedDT);
      return true;
    }
  }

  return false;
}
1157