xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp (revision 349cc55c9796c4596a5b9904cd3281af295f878f)
1e8d8bef9SDimitry Andric //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2e8d8bef9SDimitry Andric //                                    intrinsics
3e8d8bef9SDimitry Andric //
4e8d8bef9SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5e8d8bef9SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
6e8d8bef9SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7e8d8bef9SDimitry Andric //
8e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
9e8d8bef9SDimitry Andric //
10e8d8bef9SDimitry Andric // This pass replaces masked memory intrinsics - when unsupported by the target
11e8d8bef9SDimitry Andric // - with a chain of basic blocks, that deal with the elements one-by-one if the
12e8d8bef9SDimitry Andric // appropriate mask bit is set.
13e8d8bef9SDimitry Andric //
14e8d8bef9SDimitry Andric //===----------------------------------------------------------------------===//
15e8d8bef9SDimitry Andric 
16e8d8bef9SDimitry Andric #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
17e8d8bef9SDimitry Andric #include "llvm/ADT/Twine.h"
18fe6060f1SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h"
19e8d8bef9SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
20e8d8bef9SDimitry Andric #include "llvm/IR/BasicBlock.h"
21e8d8bef9SDimitry Andric #include "llvm/IR/Constant.h"
22e8d8bef9SDimitry Andric #include "llvm/IR/Constants.h"
23e8d8bef9SDimitry Andric #include "llvm/IR/DerivedTypes.h"
24fe6060f1SDimitry Andric #include "llvm/IR/Dominators.h"
25e8d8bef9SDimitry Andric #include "llvm/IR/Function.h"
26e8d8bef9SDimitry Andric #include "llvm/IR/IRBuilder.h"
27e8d8bef9SDimitry Andric #include "llvm/IR/InstrTypes.h"
28e8d8bef9SDimitry Andric #include "llvm/IR/Instruction.h"
29e8d8bef9SDimitry Andric #include "llvm/IR/Instructions.h"
30e8d8bef9SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
31e8d8bef9SDimitry Andric #include "llvm/IR/Intrinsics.h"
32e8d8bef9SDimitry Andric #include "llvm/IR/Type.h"
33e8d8bef9SDimitry Andric #include "llvm/IR/Value.h"
34e8d8bef9SDimitry Andric #include "llvm/InitializePasses.h"
35e8d8bef9SDimitry Andric #include "llvm/Pass.h"
36e8d8bef9SDimitry Andric #include "llvm/Support/Casting.h"
37e8d8bef9SDimitry Andric #include "llvm/Transforms/Scalar.h"
38fe6060f1SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
39e8d8bef9SDimitry Andric #include <algorithm>
40e8d8bef9SDimitry Andric #include <cassert>
41e8d8bef9SDimitry Andric 
42e8d8bef9SDimitry Andric using namespace llvm;
43e8d8bef9SDimitry Andric 
44e8d8bef9SDimitry Andric #define DEBUG_TYPE "scalarize-masked-mem-intrin"
45e8d8bef9SDimitry Andric 
46e8d8bef9SDimitry Andric namespace {
47e8d8bef9SDimitry Andric 
48e8d8bef9SDimitry Andric class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
49e8d8bef9SDimitry Andric public:
50e8d8bef9SDimitry Andric   static char ID; // Pass identification, replacement for typeid
51e8d8bef9SDimitry Andric 
52e8d8bef9SDimitry Andric   explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
53e8d8bef9SDimitry Andric     initializeScalarizeMaskedMemIntrinLegacyPassPass(
54e8d8bef9SDimitry Andric         *PassRegistry::getPassRegistry());
55e8d8bef9SDimitry Andric   }
56e8d8bef9SDimitry Andric 
57e8d8bef9SDimitry Andric   bool runOnFunction(Function &F) override;
58e8d8bef9SDimitry Andric 
59e8d8bef9SDimitry Andric   StringRef getPassName() const override {
60e8d8bef9SDimitry Andric     return "Scalarize Masked Memory Intrinsics";
61e8d8bef9SDimitry Andric   }
62e8d8bef9SDimitry Andric 
63e8d8bef9SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
64e8d8bef9SDimitry Andric     AU.addRequired<TargetTransformInfoWrapperPass>();
65fe6060f1SDimitry Andric     AU.addPreserved<DominatorTreeWrapperPass>();
66e8d8bef9SDimitry Andric   }
67e8d8bef9SDimitry Andric };
68e8d8bef9SDimitry Andric 
69e8d8bef9SDimitry Andric } // end anonymous namespace
70e8d8bef9SDimitry Andric 
71e8d8bef9SDimitry Andric static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
72fe6060f1SDimitry Andric                           const TargetTransformInfo &TTI, const DataLayout &DL,
73fe6060f1SDimitry Andric                           DomTreeUpdater *DTU);
74e8d8bef9SDimitry Andric static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
75e8d8bef9SDimitry Andric                              const TargetTransformInfo &TTI,
76fe6060f1SDimitry Andric                              const DataLayout &DL, DomTreeUpdater *DTU);
77e8d8bef9SDimitry Andric 
78e8d8bef9SDimitry Andric char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
79e8d8bef9SDimitry Andric 
80e8d8bef9SDimitry Andric INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
81e8d8bef9SDimitry Andric                       "Scalarize unsupported masked memory intrinsics", false,
82e8d8bef9SDimitry Andric                       false)
83e8d8bef9SDimitry Andric INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
84fe6060f1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
85e8d8bef9SDimitry Andric INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
86e8d8bef9SDimitry Andric                     "Scalarize unsupported masked memory intrinsics", false,
87e8d8bef9SDimitry Andric                     false)
88e8d8bef9SDimitry Andric 
89e8d8bef9SDimitry Andric FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
90e8d8bef9SDimitry Andric   return new ScalarizeMaskedMemIntrinLegacyPass();
91e8d8bef9SDimitry Andric }
92e8d8bef9SDimitry Andric 
93e8d8bef9SDimitry Andric static bool isConstantIntVector(Value *Mask) {
94e8d8bef9SDimitry Andric   Constant *C = dyn_cast<Constant>(Mask);
95e8d8bef9SDimitry Andric   if (!C)
96e8d8bef9SDimitry Andric     return false;
97e8d8bef9SDimitry Andric 
98e8d8bef9SDimitry Andric   unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
99e8d8bef9SDimitry Andric   for (unsigned i = 0; i != NumElts; ++i) {
100e8d8bef9SDimitry Andric     Constant *CElt = C->getAggregateElement(i);
101e8d8bef9SDimitry Andric     if (!CElt || !isa<ConstantInt>(CElt))
102e8d8bef9SDimitry Andric       return false;
103e8d8bef9SDimitry Andric   }
104e8d8bef9SDimitry Andric 
105e8d8bef9SDimitry Andric   return true;
106e8d8bef9SDimitry Andric }
107e8d8bef9SDimitry Andric 
108fe6060f1SDimitry Andric static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
109fe6060f1SDimitry Andric                                 unsigned Idx) {
110fe6060f1SDimitry Andric   return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
111fe6060f1SDimitry Andric }
112fe6060f1SDimitry Andric 
113e8d8bef9SDimitry Andric // Translate a masked load intrinsic like
114e8d8bef9SDimitry Andric // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
115e8d8bef9SDimitry Andric //                               <16 x i1> %mask, <16 x i32> %passthru)
116e8d8bef9SDimitry Andric // to a chain of basic blocks, with loading element one-by-one if
117e8d8bef9SDimitry Andric // the appropriate mask bit is set
118e8d8bef9SDimitry Andric //
119e8d8bef9SDimitry Andric //  %1 = bitcast i8* %addr to i32*
120e8d8bef9SDimitry Andric //  %2 = extractelement <16 x i1> %mask, i32 0
121e8d8bef9SDimitry Andric //  br i1 %2, label %cond.load, label %else
122e8d8bef9SDimitry Andric //
123e8d8bef9SDimitry Andric // cond.load:                                        ; preds = %0
124e8d8bef9SDimitry Andric //  %3 = getelementptr i32* %1, i32 0
125e8d8bef9SDimitry Andric //  %4 = load i32* %3
126e8d8bef9SDimitry Andric //  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
127e8d8bef9SDimitry Andric //  br label %else
128e8d8bef9SDimitry Andric //
129e8d8bef9SDimitry Andric // else:                                             ; preds = %0, %cond.load
130e8d8bef9SDimitry Andric //  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
131e8d8bef9SDimitry Andric //  %6 = extractelement <16 x i1> %mask, i32 1
132e8d8bef9SDimitry Andric //  br i1 %6, label %cond.load1, label %else2
133e8d8bef9SDimitry Andric //
134e8d8bef9SDimitry Andric // cond.load1:                                       ; preds = %else
135e8d8bef9SDimitry Andric //  %7 = getelementptr i32* %1, i32 1
136e8d8bef9SDimitry Andric //  %8 = load i32* %7
137e8d8bef9SDimitry Andric //  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
138e8d8bef9SDimitry Andric //  br label %else2
139e8d8bef9SDimitry Andric //
140e8d8bef9SDimitry Andric // else2:                                          ; preds = %else, %cond.load1
141e8d8bef9SDimitry Andric //  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
142e8d8bef9SDimitry Andric //  %10 = extractelement <16 x i1> %mask, i32 2
143e8d8bef9SDimitry Andric //  br i1 %10, label %cond.load4, label %else5
144e8d8bef9SDimitry Andric //
145fe6060f1SDimitry Andric static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
146fe6060f1SDimitry Andric                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
147e8d8bef9SDimitry Andric   Value *Ptr = CI->getArgOperand(0);
148e8d8bef9SDimitry Andric   Value *Alignment = CI->getArgOperand(1);
149e8d8bef9SDimitry Andric   Value *Mask = CI->getArgOperand(2);
150e8d8bef9SDimitry Andric   Value *Src0 = CI->getArgOperand(3);
151e8d8bef9SDimitry Andric 
152e8d8bef9SDimitry Andric   const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
153e8d8bef9SDimitry Andric   VectorType *VecType = cast<FixedVectorType>(CI->getType());
154e8d8bef9SDimitry Andric 
155e8d8bef9SDimitry Andric   Type *EltTy = VecType->getElementType();
156e8d8bef9SDimitry Andric 
157e8d8bef9SDimitry Andric   IRBuilder<> Builder(CI->getContext());
158e8d8bef9SDimitry Andric   Instruction *InsertPt = CI;
159e8d8bef9SDimitry Andric   BasicBlock *IfBlock = CI->getParent();
160e8d8bef9SDimitry Andric 
161e8d8bef9SDimitry Andric   Builder.SetInsertPoint(InsertPt);
162e8d8bef9SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
163e8d8bef9SDimitry Andric 
164e8d8bef9SDimitry Andric   // Short-cut if the mask is all-true.
165e8d8bef9SDimitry Andric   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
166e8d8bef9SDimitry Andric     Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
167e8d8bef9SDimitry Andric     CI->replaceAllUsesWith(NewI);
168e8d8bef9SDimitry Andric     CI->eraseFromParent();
169e8d8bef9SDimitry Andric     return;
170e8d8bef9SDimitry Andric   }
171e8d8bef9SDimitry Andric 
172e8d8bef9SDimitry Andric   // Adjust alignment for the scalar instruction.
173e8d8bef9SDimitry Andric   const Align AdjustedAlignVal =
174e8d8bef9SDimitry Andric       commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
175e8d8bef9SDimitry Andric   // Bitcast %addr from i8* to EltTy*
176e8d8bef9SDimitry Andric   Type *NewPtrType =
177e8d8bef9SDimitry Andric       EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
178e8d8bef9SDimitry Andric   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
179e8d8bef9SDimitry Andric   unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
180e8d8bef9SDimitry Andric 
181e8d8bef9SDimitry Andric   // The result vector
182e8d8bef9SDimitry Andric   Value *VResult = Src0;
183e8d8bef9SDimitry Andric 
184e8d8bef9SDimitry Andric   if (isConstantIntVector(Mask)) {
185e8d8bef9SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
186e8d8bef9SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
187e8d8bef9SDimitry Andric         continue;
188e8d8bef9SDimitry Andric       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
189e8d8bef9SDimitry Andric       LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
190e8d8bef9SDimitry Andric       VResult = Builder.CreateInsertElement(VResult, Load, Idx);
191e8d8bef9SDimitry Andric     }
192e8d8bef9SDimitry Andric     CI->replaceAllUsesWith(VResult);
193e8d8bef9SDimitry Andric     CI->eraseFromParent();
194e8d8bef9SDimitry Andric     return;
195e8d8bef9SDimitry Andric   }
196e8d8bef9SDimitry Andric 
197e8d8bef9SDimitry Andric   // If the mask is not v1i1, use scalar bit test operations. This generates
198e8d8bef9SDimitry Andric   // better results on X86 at least.
199e8d8bef9SDimitry Andric   Value *SclrMask;
200e8d8bef9SDimitry Andric   if (VectorWidth != 1) {
201e8d8bef9SDimitry Andric     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
202e8d8bef9SDimitry Andric     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
203e8d8bef9SDimitry Andric   }
204e8d8bef9SDimitry Andric 
205e8d8bef9SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
206e8d8bef9SDimitry Andric     // Fill the "else" block, created in the previous iteration
207e8d8bef9SDimitry Andric     //
208e8d8bef9SDimitry Andric     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
209e8d8bef9SDimitry Andric     //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
210e8d8bef9SDimitry Andric     //  %cond = icmp ne i16 %mask_1, 0
211e8d8bef9SDimitry Andric     //  br i1 %mask_1, label %cond.load, label %else
212e8d8bef9SDimitry Andric     //
213e8d8bef9SDimitry Andric     Value *Predicate;
214e8d8bef9SDimitry Andric     if (VectorWidth != 1) {
215fe6060f1SDimitry Andric       Value *Mask = Builder.getInt(APInt::getOneBitSet(
216fe6060f1SDimitry Andric           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
217e8d8bef9SDimitry Andric       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
218e8d8bef9SDimitry Andric                                        Builder.getIntN(VectorWidth, 0));
219e8d8bef9SDimitry Andric     } else {
220e8d8bef9SDimitry Andric       Predicate = Builder.CreateExtractElement(Mask, Idx);
221e8d8bef9SDimitry Andric     }
222e8d8bef9SDimitry Andric 
223e8d8bef9SDimitry Andric     // Create "cond" block
224e8d8bef9SDimitry Andric     //
225e8d8bef9SDimitry Andric     //  %EltAddr = getelementptr i32* %1, i32 0
226e8d8bef9SDimitry Andric     //  %Elt = load i32* %EltAddr
227e8d8bef9SDimitry Andric     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
228e8d8bef9SDimitry Andric     //
229fe6060f1SDimitry Andric     Instruction *ThenTerm =
230fe6060f1SDimitry Andric         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
231fe6060f1SDimitry Andric                                   /*BranchWeights=*/nullptr, DTU);
232e8d8bef9SDimitry Andric 
233fe6060f1SDimitry Andric     BasicBlock *CondBlock = ThenTerm->getParent();
234fe6060f1SDimitry Andric     CondBlock->setName("cond.load");
235fe6060f1SDimitry Andric 
236fe6060f1SDimitry Andric     Builder.SetInsertPoint(CondBlock->getTerminator());
237e8d8bef9SDimitry Andric     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
238e8d8bef9SDimitry Andric     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
239e8d8bef9SDimitry Andric     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
240e8d8bef9SDimitry Andric 
241e8d8bef9SDimitry Andric     // Create "else" block, fill it in the next iteration
242fe6060f1SDimitry Andric     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
243fe6060f1SDimitry Andric     NewIfBlock->setName("else");
244e8d8bef9SDimitry Andric     BasicBlock *PrevIfBlock = IfBlock;
245e8d8bef9SDimitry Andric     IfBlock = NewIfBlock;
246e8d8bef9SDimitry Andric 
247e8d8bef9SDimitry Andric     // Create the phi to join the new and previous value.
248fe6060f1SDimitry Andric     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
249e8d8bef9SDimitry Andric     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
250e8d8bef9SDimitry Andric     Phi->addIncoming(NewVResult, CondBlock);
251e8d8bef9SDimitry Andric     Phi->addIncoming(VResult, PrevIfBlock);
252e8d8bef9SDimitry Andric     VResult = Phi;
253e8d8bef9SDimitry Andric   }
254e8d8bef9SDimitry Andric 
255e8d8bef9SDimitry Andric   CI->replaceAllUsesWith(VResult);
256e8d8bef9SDimitry Andric   CI->eraseFromParent();
257e8d8bef9SDimitry Andric 
258e8d8bef9SDimitry Andric   ModifiedDT = true;
259e8d8bef9SDimitry Andric }
260e8d8bef9SDimitry Andric 
261e8d8bef9SDimitry Andric // Translate a masked store intrinsic, like
262e8d8bef9SDimitry Andric // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
263e8d8bef9SDimitry Andric //                               <16 x i1> %mask)
264e8d8bef9SDimitry Andric // to a chain of basic blocks, that stores element one-by-one if
265e8d8bef9SDimitry Andric // the appropriate mask bit is set
266e8d8bef9SDimitry Andric //
267e8d8bef9SDimitry Andric //   %1 = bitcast i8* %addr to i32*
268e8d8bef9SDimitry Andric //   %2 = extractelement <16 x i1> %mask, i32 0
269e8d8bef9SDimitry Andric //   br i1 %2, label %cond.store, label %else
270e8d8bef9SDimitry Andric //
271e8d8bef9SDimitry Andric // cond.store:                                       ; preds = %0
272e8d8bef9SDimitry Andric //   %3 = extractelement <16 x i32> %val, i32 0
273e8d8bef9SDimitry Andric //   %4 = getelementptr i32* %1, i32 0
274e8d8bef9SDimitry Andric //   store i32 %3, i32* %4
275e8d8bef9SDimitry Andric //   br label %else
276e8d8bef9SDimitry Andric //
277e8d8bef9SDimitry Andric // else:                                             ; preds = %0, %cond.store
278e8d8bef9SDimitry Andric //   %5 = extractelement <16 x i1> %mask, i32 1
279e8d8bef9SDimitry Andric //   br i1 %5, label %cond.store1, label %else2
280e8d8bef9SDimitry Andric //
281e8d8bef9SDimitry Andric // cond.store1:                                      ; preds = %else
282e8d8bef9SDimitry Andric //   %6 = extractelement <16 x i32> %val, i32 1
283e8d8bef9SDimitry Andric //   %7 = getelementptr i32* %1, i32 1
284e8d8bef9SDimitry Andric //   store i32 %6, i32* %7
285e8d8bef9SDimitry Andric //   br label %else2
286e8d8bef9SDimitry Andric //   . . .
287fe6060f1SDimitry Andric static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
288fe6060f1SDimitry Andric                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
289e8d8bef9SDimitry Andric   Value *Src = CI->getArgOperand(0);
290e8d8bef9SDimitry Andric   Value *Ptr = CI->getArgOperand(1);
291e8d8bef9SDimitry Andric   Value *Alignment = CI->getArgOperand(2);
292e8d8bef9SDimitry Andric   Value *Mask = CI->getArgOperand(3);
293e8d8bef9SDimitry Andric 
294e8d8bef9SDimitry Andric   const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
295e8d8bef9SDimitry Andric   auto *VecType = cast<VectorType>(Src->getType());
296e8d8bef9SDimitry Andric 
297e8d8bef9SDimitry Andric   Type *EltTy = VecType->getElementType();
298e8d8bef9SDimitry Andric 
299e8d8bef9SDimitry Andric   IRBuilder<> Builder(CI->getContext());
300e8d8bef9SDimitry Andric   Instruction *InsertPt = CI;
301e8d8bef9SDimitry Andric   Builder.SetInsertPoint(InsertPt);
302e8d8bef9SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
303e8d8bef9SDimitry Andric 
304e8d8bef9SDimitry Andric   // Short-cut if the mask is all-true.
305e8d8bef9SDimitry Andric   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
306e8d8bef9SDimitry Andric     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
307e8d8bef9SDimitry Andric     CI->eraseFromParent();
308e8d8bef9SDimitry Andric     return;
309e8d8bef9SDimitry Andric   }
310e8d8bef9SDimitry Andric 
311e8d8bef9SDimitry Andric   // Adjust alignment for the scalar instruction.
312e8d8bef9SDimitry Andric   const Align AdjustedAlignVal =
313e8d8bef9SDimitry Andric       commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
314e8d8bef9SDimitry Andric   // Bitcast %addr from i8* to EltTy*
315e8d8bef9SDimitry Andric   Type *NewPtrType =
316e8d8bef9SDimitry Andric       EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
317e8d8bef9SDimitry Andric   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
318e8d8bef9SDimitry Andric   unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
319e8d8bef9SDimitry Andric 
320e8d8bef9SDimitry Andric   if (isConstantIntVector(Mask)) {
321e8d8bef9SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
322e8d8bef9SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
323e8d8bef9SDimitry Andric         continue;
324e8d8bef9SDimitry Andric       Value *OneElt = Builder.CreateExtractElement(Src, Idx);
325e8d8bef9SDimitry Andric       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
326e8d8bef9SDimitry Andric       Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
327e8d8bef9SDimitry Andric     }
328e8d8bef9SDimitry Andric     CI->eraseFromParent();
329e8d8bef9SDimitry Andric     return;
330e8d8bef9SDimitry Andric   }
331e8d8bef9SDimitry Andric 
332e8d8bef9SDimitry Andric   // If the mask is not v1i1, use scalar bit test operations. This generates
333e8d8bef9SDimitry Andric   // better results on X86 at least.
334e8d8bef9SDimitry Andric   Value *SclrMask;
335e8d8bef9SDimitry Andric   if (VectorWidth != 1) {
336e8d8bef9SDimitry Andric     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
337e8d8bef9SDimitry Andric     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
338e8d8bef9SDimitry Andric   }
339e8d8bef9SDimitry Andric 
340e8d8bef9SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
341e8d8bef9SDimitry Andric     // Fill the "else" block, created in the previous iteration
342e8d8bef9SDimitry Andric     //
343e8d8bef9SDimitry Andric     //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
344e8d8bef9SDimitry Andric     //  %cond = icmp ne i16 %mask_1, 0
345e8d8bef9SDimitry Andric     //  br i1 %mask_1, label %cond.store, label %else
346e8d8bef9SDimitry Andric     //
347e8d8bef9SDimitry Andric     Value *Predicate;
348e8d8bef9SDimitry Andric     if (VectorWidth != 1) {
349fe6060f1SDimitry Andric       Value *Mask = Builder.getInt(APInt::getOneBitSet(
350fe6060f1SDimitry Andric           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
351e8d8bef9SDimitry Andric       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
352e8d8bef9SDimitry Andric                                        Builder.getIntN(VectorWidth, 0));
353e8d8bef9SDimitry Andric     } else {
354e8d8bef9SDimitry Andric       Predicate = Builder.CreateExtractElement(Mask, Idx);
355e8d8bef9SDimitry Andric     }
356e8d8bef9SDimitry Andric 
357e8d8bef9SDimitry Andric     // Create "cond" block
358e8d8bef9SDimitry Andric     //
359e8d8bef9SDimitry Andric     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
360e8d8bef9SDimitry Andric     //  %EltAddr = getelementptr i32* %1, i32 0
361e8d8bef9SDimitry Andric     //  %store i32 %OneElt, i32* %EltAddr
362e8d8bef9SDimitry Andric     //
363fe6060f1SDimitry Andric     Instruction *ThenTerm =
364fe6060f1SDimitry Andric         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
365fe6060f1SDimitry Andric                                   /*BranchWeights=*/nullptr, DTU);
366e8d8bef9SDimitry Andric 
367fe6060f1SDimitry Andric     BasicBlock *CondBlock = ThenTerm->getParent();
368fe6060f1SDimitry Andric     CondBlock->setName("cond.store");
369fe6060f1SDimitry Andric 
370fe6060f1SDimitry Andric     Builder.SetInsertPoint(CondBlock->getTerminator());
371e8d8bef9SDimitry Andric     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
372e8d8bef9SDimitry Andric     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
373e8d8bef9SDimitry Andric     Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
374e8d8bef9SDimitry Andric 
375e8d8bef9SDimitry Andric     // Create "else" block, fill it in the next iteration
376fe6060f1SDimitry Andric     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
377fe6060f1SDimitry Andric     NewIfBlock->setName("else");
378fe6060f1SDimitry Andric 
379fe6060f1SDimitry Andric     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
380e8d8bef9SDimitry Andric   }
381e8d8bef9SDimitry Andric   CI->eraseFromParent();
382e8d8bef9SDimitry Andric 
383e8d8bef9SDimitry Andric   ModifiedDT = true;
384e8d8bef9SDimitry Andric }
385e8d8bef9SDimitry Andric 
386e8d8bef9SDimitry Andric // Translate a masked gather intrinsic like
387e8d8bef9SDimitry Andric // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
388e8d8bef9SDimitry Andric //                               <16 x i1> %Mask, <16 x i32> %Src)
389e8d8bef9SDimitry Andric // to a chain of basic blocks, with loading element one-by-one if
390e8d8bef9SDimitry Andric // the appropriate mask bit is set
391e8d8bef9SDimitry Andric //
392e8d8bef9SDimitry Andric // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
393e8d8bef9SDimitry Andric // %Mask0 = extractelement <16 x i1> %Mask, i32 0
394e8d8bef9SDimitry Andric // br i1 %Mask0, label %cond.load, label %else
395e8d8bef9SDimitry Andric //
396e8d8bef9SDimitry Andric // cond.load:
397e8d8bef9SDimitry Andric // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
398e8d8bef9SDimitry Andric // %Load0 = load i32, i32* %Ptr0, align 4
399e8d8bef9SDimitry Andric // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
400e8d8bef9SDimitry Andric // br label %else
401e8d8bef9SDimitry Andric //
402e8d8bef9SDimitry Andric // else:
403e8d8bef9SDimitry Andric // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
404e8d8bef9SDimitry Andric // %Mask1 = extractelement <16 x i1> %Mask, i32 1
405e8d8bef9SDimitry Andric // br i1 %Mask1, label %cond.load1, label %else2
406e8d8bef9SDimitry Andric //
407e8d8bef9SDimitry Andric // cond.load1:
408e8d8bef9SDimitry Andric // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
409e8d8bef9SDimitry Andric // %Load1 = load i32, i32* %Ptr1, align 4
410e8d8bef9SDimitry Andric // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
411e8d8bef9SDimitry Andric // br label %else2
412e8d8bef9SDimitry Andric // . . .
413e8d8bef9SDimitry Andric // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
414e8d8bef9SDimitry Andric // ret <16 x i32> %Result
415fe6060f1SDimitry Andric static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
416fe6060f1SDimitry Andric                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
417e8d8bef9SDimitry Andric   Value *Ptrs = CI->getArgOperand(0);
418e8d8bef9SDimitry Andric   Value *Alignment = CI->getArgOperand(1);
419e8d8bef9SDimitry Andric   Value *Mask = CI->getArgOperand(2);
420e8d8bef9SDimitry Andric   Value *Src0 = CI->getArgOperand(3);
421e8d8bef9SDimitry Andric 
422e8d8bef9SDimitry Andric   auto *VecType = cast<FixedVectorType>(CI->getType());
423e8d8bef9SDimitry Andric   Type *EltTy = VecType->getElementType();
424e8d8bef9SDimitry Andric 
425e8d8bef9SDimitry Andric   IRBuilder<> Builder(CI->getContext());
426e8d8bef9SDimitry Andric   Instruction *InsertPt = CI;
427e8d8bef9SDimitry Andric   BasicBlock *IfBlock = CI->getParent();
428e8d8bef9SDimitry Andric   Builder.SetInsertPoint(InsertPt);
429e8d8bef9SDimitry Andric   MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
430e8d8bef9SDimitry Andric 
431e8d8bef9SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
432e8d8bef9SDimitry Andric 
433e8d8bef9SDimitry Andric   // The result vector
434e8d8bef9SDimitry Andric   Value *VResult = Src0;
435e8d8bef9SDimitry Andric   unsigned VectorWidth = VecType->getNumElements();
436e8d8bef9SDimitry Andric 
437e8d8bef9SDimitry Andric   // Shorten the way if the mask is a vector of constants.
438e8d8bef9SDimitry Andric   if (isConstantIntVector(Mask)) {
439e8d8bef9SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
440e8d8bef9SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
441e8d8bef9SDimitry Andric         continue;
442e8d8bef9SDimitry Andric       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
443e8d8bef9SDimitry Andric       LoadInst *Load =
444e8d8bef9SDimitry Andric           Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
445e8d8bef9SDimitry Andric       VResult =
446e8d8bef9SDimitry Andric           Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
447e8d8bef9SDimitry Andric     }
448e8d8bef9SDimitry Andric     CI->replaceAllUsesWith(VResult);
449e8d8bef9SDimitry Andric     CI->eraseFromParent();
450e8d8bef9SDimitry Andric     return;
451e8d8bef9SDimitry Andric   }
452e8d8bef9SDimitry Andric 
453e8d8bef9SDimitry Andric   // If the mask is not v1i1, use scalar bit test operations. This generates
454e8d8bef9SDimitry Andric   // better results on X86 at least.
455e8d8bef9SDimitry Andric   Value *SclrMask;
456e8d8bef9SDimitry Andric   if (VectorWidth != 1) {
457e8d8bef9SDimitry Andric     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
458e8d8bef9SDimitry Andric     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
459e8d8bef9SDimitry Andric   }
460e8d8bef9SDimitry Andric 
461e8d8bef9SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
462e8d8bef9SDimitry Andric     // Fill the "else" block, created in the previous iteration
463e8d8bef9SDimitry Andric     //
464e8d8bef9SDimitry Andric     //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
465e8d8bef9SDimitry Andric     //  %cond = icmp ne i16 %mask_1, 0
466e8d8bef9SDimitry Andric     //  br i1 %Mask1, label %cond.load, label %else
467e8d8bef9SDimitry Andric     //
468e8d8bef9SDimitry Andric 
469e8d8bef9SDimitry Andric     Value *Predicate;
470e8d8bef9SDimitry Andric     if (VectorWidth != 1) {
471fe6060f1SDimitry Andric       Value *Mask = Builder.getInt(APInt::getOneBitSet(
472fe6060f1SDimitry Andric           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
473e8d8bef9SDimitry Andric       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
474e8d8bef9SDimitry Andric                                        Builder.getIntN(VectorWidth, 0));
475e8d8bef9SDimitry Andric     } else {
476e8d8bef9SDimitry Andric       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
477e8d8bef9SDimitry Andric     }
478e8d8bef9SDimitry Andric 
479e8d8bef9SDimitry Andric     // Create "cond" block
480e8d8bef9SDimitry Andric     //
481e8d8bef9SDimitry Andric     //  %EltAddr = getelementptr i32* %1, i32 0
482e8d8bef9SDimitry Andric     //  %Elt = load i32* %EltAddr
483e8d8bef9SDimitry Andric     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
484e8d8bef9SDimitry Andric     //
485fe6060f1SDimitry Andric     Instruction *ThenTerm =
486fe6060f1SDimitry Andric         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
487fe6060f1SDimitry Andric                                   /*BranchWeights=*/nullptr, DTU);
488e8d8bef9SDimitry Andric 
489fe6060f1SDimitry Andric     BasicBlock *CondBlock = ThenTerm->getParent();
490fe6060f1SDimitry Andric     CondBlock->setName("cond.load");
491fe6060f1SDimitry Andric 
492fe6060f1SDimitry Andric     Builder.SetInsertPoint(CondBlock->getTerminator());
493e8d8bef9SDimitry Andric     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
494e8d8bef9SDimitry Andric     LoadInst *Load =
495e8d8bef9SDimitry Andric         Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
496e8d8bef9SDimitry Andric     Value *NewVResult =
497e8d8bef9SDimitry Andric         Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
498e8d8bef9SDimitry Andric 
499e8d8bef9SDimitry Andric     // Create "else" block, fill it in the next iteration
500fe6060f1SDimitry Andric     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
501fe6060f1SDimitry Andric     NewIfBlock->setName("else");
502e8d8bef9SDimitry Andric     BasicBlock *PrevIfBlock = IfBlock;
503e8d8bef9SDimitry Andric     IfBlock = NewIfBlock;
504e8d8bef9SDimitry Andric 
505fe6060f1SDimitry Andric     // Create the phi to join the new and previous value.
506fe6060f1SDimitry Andric     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
507e8d8bef9SDimitry Andric     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
508e8d8bef9SDimitry Andric     Phi->addIncoming(NewVResult, CondBlock);
509e8d8bef9SDimitry Andric     Phi->addIncoming(VResult, PrevIfBlock);
510e8d8bef9SDimitry Andric     VResult = Phi;
511e8d8bef9SDimitry Andric   }
512e8d8bef9SDimitry Andric 
513e8d8bef9SDimitry Andric   CI->replaceAllUsesWith(VResult);
514e8d8bef9SDimitry Andric   CI->eraseFromParent();
515e8d8bef9SDimitry Andric 
516e8d8bef9SDimitry Andric   ModifiedDT = true;
517e8d8bef9SDimitry Andric }
518e8d8bef9SDimitry Andric 
519e8d8bef9SDimitry Andric // Translate a masked scatter intrinsic, like
520e8d8bef9SDimitry Andric // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
521e8d8bef9SDimitry Andric //                                  <16 x i1> %Mask)
522e8d8bef9SDimitry Andric // to a chain of basic blocks, that stores element one-by-one if
523e8d8bef9SDimitry Andric // the appropriate mask bit is set.
524e8d8bef9SDimitry Andric //
525e8d8bef9SDimitry Andric // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
526e8d8bef9SDimitry Andric // %Mask0 = extractelement <16 x i1> %Mask, i32 0
527e8d8bef9SDimitry Andric // br i1 %Mask0, label %cond.store, label %else
528e8d8bef9SDimitry Andric //
529e8d8bef9SDimitry Andric // cond.store:
530e8d8bef9SDimitry Andric // %Elt0 = extractelement <16 x i32> %Src, i32 0
531e8d8bef9SDimitry Andric // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
532e8d8bef9SDimitry Andric // store i32 %Elt0, i32* %Ptr0, align 4
533e8d8bef9SDimitry Andric // br label %else
534e8d8bef9SDimitry Andric //
535e8d8bef9SDimitry Andric // else:
536e8d8bef9SDimitry Andric // %Mask1 = extractelement <16 x i1> %Mask, i32 1
537e8d8bef9SDimitry Andric // br i1 %Mask1, label %cond.store1, label %else2
538e8d8bef9SDimitry Andric //
539e8d8bef9SDimitry Andric // cond.store1:
540e8d8bef9SDimitry Andric // %Elt1 = extractelement <16 x i32> %Src, i32 1
541e8d8bef9SDimitry Andric // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
542e8d8bef9SDimitry Andric // store i32 %Elt1, i32* %Ptr1, align 4
543e8d8bef9SDimitry Andric // br label %else2
544e8d8bef9SDimitry Andric //   . . .
545fe6060f1SDimitry Andric static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
546fe6060f1SDimitry Andric                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
547e8d8bef9SDimitry Andric   Value *Src = CI->getArgOperand(0);
548e8d8bef9SDimitry Andric   Value *Ptrs = CI->getArgOperand(1);
549e8d8bef9SDimitry Andric   Value *Alignment = CI->getArgOperand(2);
550e8d8bef9SDimitry Andric   Value *Mask = CI->getArgOperand(3);
551e8d8bef9SDimitry Andric 
552e8d8bef9SDimitry Andric   auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
553e8d8bef9SDimitry Andric 
554e8d8bef9SDimitry Andric   assert(
555e8d8bef9SDimitry Andric       isa<VectorType>(Ptrs->getType()) &&
556e8d8bef9SDimitry Andric       isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
557e8d8bef9SDimitry Andric       "Vector of pointers is expected in masked scatter intrinsic");
558e8d8bef9SDimitry Andric 
559e8d8bef9SDimitry Andric   IRBuilder<> Builder(CI->getContext());
560e8d8bef9SDimitry Andric   Instruction *InsertPt = CI;
561e8d8bef9SDimitry Andric   Builder.SetInsertPoint(InsertPt);
562e8d8bef9SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
563e8d8bef9SDimitry Andric 
564e8d8bef9SDimitry Andric   MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
565e8d8bef9SDimitry Andric   unsigned VectorWidth = SrcFVTy->getNumElements();
566e8d8bef9SDimitry Andric 
567e8d8bef9SDimitry Andric   // Shorten the way if the mask is a vector of constants.
568e8d8bef9SDimitry Andric   if (isConstantIntVector(Mask)) {
569e8d8bef9SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
570e8d8bef9SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
571e8d8bef9SDimitry Andric         continue;
572e8d8bef9SDimitry Andric       Value *OneElt =
573e8d8bef9SDimitry Andric           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
574e8d8bef9SDimitry Andric       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
575e8d8bef9SDimitry Andric       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
576e8d8bef9SDimitry Andric     }
577e8d8bef9SDimitry Andric     CI->eraseFromParent();
578e8d8bef9SDimitry Andric     return;
579e8d8bef9SDimitry Andric   }
580e8d8bef9SDimitry Andric 
581e8d8bef9SDimitry Andric   // If the mask is not v1i1, use scalar bit test operations. This generates
582e8d8bef9SDimitry Andric   // better results on X86 at least.
583e8d8bef9SDimitry Andric   Value *SclrMask;
584e8d8bef9SDimitry Andric   if (VectorWidth != 1) {
585e8d8bef9SDimitry Andric     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
586e8d8bef9SDimitry Andric     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
587e8d8bef9SDimitry Andric   }
588e8d8bef9SDimitry Andric 
589e8d8bef9SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
590e8d8bef9SDimitry Andric     // Fill the "else" block, created in the previous iteration
591e8d8bef9SDimitry Andric     //
592e8d8bef9SDimitry Andric     //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
593e8d8bef9SDimitry Andric     //  %cond = icmp ne i16 %mask_1, 0
594e8d8bef9SDimitry Andric     //  br i1 %Mask1, label %cond.store, label %else
595e8d8bef9SDimitry Andric     //
596e8d8bef9SDimitry Andric     Value *Predicate;
597e8d8bef9SDimitry Andric     if (VectorWidth != 1) {
598fe6060f1SDimitry Andric       Value *Mask = Builder.getInt(APInt::getOneBitSet(
599fe6060f1SDimitry Andric           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
600e8d8bef9SDimitry Andric       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
601e8d8bef9SDimitry Andric                                        Builder.getIntN(VectorWidth, 0));
602e8d8bef9SDimitry Andric     } else {
603e8d8bef9SDimitry Andric       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
604e8d8bef9SDimitry Andric     }
605e8d8bef9SDimitry Andric 
606e8d8bef9SDimitry Andric     // Create "cond" block
607e8d8bef9SDimitry Andric     //
608e8d8bef9SDimitry Andric     //  %Elt1 = extractelement <16 x i32> %Src, i32 1
609e8d8bef9SDimitry Andric     //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
610e8d8bef9SDimitry Andric     //  %store i32 %Elt1, i32* %Ptr1
611e8d8bef9SDimitry Andric     //
612fe6060f1SDimitry Andric     Instruction *ThenTerm =
613fe6060f1SDimitry Andric         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
614fe6060f1SDimitry Andric                                   /*BranchWeights=*/nullptr, DTU);
615e8d8bef9SDimitry Andric 
616fe6060f1SDimitry Andric     BasicBlock *CondBlock = ThenTerm->getParent();
617fe6060f1SDimitry Andric     CondBlock->setName("cond.store");
618fe6060f1SDimitry Andric 
619fe6060f1SDimitry Andric     Builder.SetInsertPoint(CondBlock->getTerminator());
620e8d8bef9SDimitry Andric     Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
621e8d8bef9SDimitry Andric     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
622e8d8bef9SDimitry Andric     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
623e8d8bef9SDimitry Andric 
624e8d8bef9SDimitry Andric     // Create "else" block, fill it in the next iteration
625fe6060f1SDimitry Andric     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
626fe6060f1SDimitry Andric     NewIfBlock->setName("else");
627fe6060f1SDimitry Andric 
628fe6060f1SDimitry Andric     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
629e8d8bef9SDimitry Andric   }
630e8d8bef9SDimitry Andric   CI->eraseFromParent();
631e8d8bef9SDimitry Andric 
632e8d8bef9SDimitry Andric   ModifiedDT = true;
633e8d8bef9SDimitry Andric }
634e8d8bef9SDimitry Andric 
635fe6060f1SDimitry Andric static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
636fe6060f1SDimitry Andric                                       DomTreeUpdater *DTU, bool &ModifiedDT) {
637e8d8bef9SDimitry Andric   Value *Ptr = CI->getArgOperand(0);
638e8d8bef9SDimitry Andric   Value *Mask = CI->getArgOperand(1);
639e8d8bef9SDimitry Andric   Value *PassThru = CI->getArgOperand(2);
640e8d8bef9SDimitry Andric 
641e8d8bef9SDimitry Andric   auto *VecType = cast<FixedVectorType>(CI->getType());
642e8d8bef9SDimitry Andric 
643e8d8bef9SDimitry Andric   Type *EltTy = VecType->getElementType();
644e8d8bef9SDimitry Andric 
645e8d8bef9SDimitry Andric   IRBuilder<> Builder(CI->getContext());
646e8d8bef9SDimitry Andric   Instruction *InsertPt = CI;
647e8d8bef9SDimitry Andric   BasicBlock *IfBlock = CI->getParent();
648e8d8bef9SDimitry Andric 
649e8d8bef9SDimitry Andric   Builder.SetInsertPoint(InsertPt);
650e8d8bef9SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
651e8d8bef9SDimitry Andric 
652e8d8bef9SDimitry Andric   unsigned VectorWidth = VecType->getNumElements();
653e8d8bef9SDimitry Andric 
654e8d8bef9SDimitry Andric   // The result vector
655e8d8bef9SDimitry Andric   Value *VResult = PassThru;
656e8d8bef9SDimitry Andric 
657e8d8bef9SDimitry Andric   // Shorten the way if the mask is a vector of constants.
658e8d8bef9SDimitry Andric   // Create a build_vector pattern, with loads/undefs as necessary and then
659e8d8bef9SDimitry Andric   // shuffle blend with the pass through value.
660e8d8bef9SDimitry Andric   if (isConstantIntVector(Mask)) {
661e8d8bef9SDimitry Andric     unsigned MemIndex = 0;
662e8d8bef9SDimitry Andric     VResult = UndefValue::get(VecType);
663e8d8bef9SDimitry Andric     SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
664e8d8bef9SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
665e8d8bef9SDimitry Andric       Value *InsertElt;
666e8d8bef9SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
667e8d8bef9SDimitry Andric         InsertElt = UndefValue::get(EltTy);
668e8d8bef9SDimitry Andric         ShuffleMask[Idx] = Idx + VectorWidth;
669e8d8bef9SDimitry Andric       } else {
670e8d8bef9SDimitry Andric         Value *NewPtr =
671e8d8bef9SDimitry Andric             Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
672e8d8bef9SDimitry Andric         InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
673e8d8bef9SDimitry Andric                                               "Load" + Twine(Idx));
674e8d8bef9SDimitry Andric         ShuffleMask[Idx] = Idx;
675e8d8bef9SDimitry Andric         ++MemIndex;
676e8d8bef9SDimitry Andric       }
677e8d8bef9SDimitry Andric       VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
678e8d8bef9SDimitry Andric                                             "Res" + Twine(Idx));
679e8d8bef9SDimitry Andric     }
680e8d8bef9SDimitry Andric     VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
681e8d8bef9SDimitry Andric     CI->replaceAllUsesWith(VResult);
682e8d8bef9SDimitry Andric     CI->eraseFromParent();
683e8d8bef9SDimitry Andric     return;
684e8d8bef9SDimitry Andric   }
685e8d8bef9SDimitry Andric 
686e8d8bef9SDimitry Andric   // If the mask is not v1i1, use scalar bit test operations. This generates
687e8d8bef9SDimitry Andric   // better results on X86 at least.
688e8d8bef9SDimitry Andric   Value *SclrMask;
689e8d8bef9SDimitry Andric   if (VectorWidth != 1) {
690e8d8bef9SDimitry Andric     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
691e8d8bef9SDimitry Andric     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
692e8d8bef9SDimitry Andric   }
693e8d8bef9SDimitry Andric 
694e8d8bef9SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
695e8d8bef9SDimitry Andric     // Fill the "else" block, created in the previous iteration
696e8d8bef9SDimitry Andric     //
697e8d8bef9SDimitry Andric     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
698e8d8bef9SDimitry Andric     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
699e8d8bef9SDimitry Andric     //  br i1 %mask_1, label %cond.load, label %else
700e8d8bef9SDimitry Andric     //
701e8d8bef9SDimitry Andric 
702e8d8bef9SDimitry Andric     Value *Predicate;
703e8d8bef9SDimitry Andric     if (VectorWidth != 1) {
704fe6060f1SDimitry Andric       Value *Mask = Builder.getInt(APInt::getOneBitSet(
705fe6060f1SDimitry Andric           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
706e8d8bef9SDimitry Andric       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
707e8d8bef9SDimitry Andric                                        Builder.getIntN(VectorWidth, 0));
708e8d8bef9SDimitry Andric     } else {
709e8d8bef9SDimitry Andric       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
710e8d8bef9SDimitry Andric     }
711e8d8bef9SDimitry Andric 
712e8d8bef9SDimitry Andric     // Create "cond" block
713e8d8bef9SDimitry Andric     //
714e8d8bef9SDimitry Andric     //  %EltAddr = getelementptr i32* %1, i32 0
715e8d8bef9SDimitry Andric     //  %Elt = load i32* %EltAddr
716e8d8bef9SDimitry Andric     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
717e8d8bef9SDimitry Andric     //
718fe6060f1SDimitry Andric     Instruction *ThenTerm =
719fe6060f1SDimitry Andric         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
720fe6060f1SDimitry Andric                                   /*BranchWeights=*/nullptr, DTU);
721e8d8bef9SDimitry Andric 
722fe6060f1SDimitry Andric     BasicBlock *CondBlock = ThenTerm->getParent();
723fe6060f1SDimitry Andric     CondBlock->setName("cond.load");
724fe6060f1SDimitry Andric 
725fe6060f1SDimitry Andric     Builder.SetInsertPoint(CondBlock->getTerminator());
726e8d8bef9SDimitry Andric     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
727e8d8bef9SDimitry Andric     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
728e8d8bef9SDimitry Andric 
729e8d8bef9SDimitry Andric     // Move the pointer if there are more blocks to come.
730e8d8bef9SDimitry Andric     Value *NewPtr;
731e8d8bef9SDimitry Andric     if ((Idx + 1) != VectorWidth)
732e8d8bef9SDimitry Andric       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
733e8d8bef9SDimitry Andric 
734e8d8bef9SDimitry Andric     // Create "else" block, fill it in the next iteration
735fe6060f1SDimitry Andric     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
736fe6060f1SDimitry Andric     NewIfBlock->setName("else");
737e8d8bef9SDimitry Andric     BasicBlock *PrevIfBlock = IfBlock;
738e8d8bef9SDimitry Andric     IfBlock = NewIfBlock;
739e8d8bef9SDimitry Andric 
740e8d8bef9SDimitry Andric     // Create the phi to join the new and previous value.
741fe6060f1SDimitry Andric     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
742e8d8bef9SDimitry Andric     PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
743e8d8bef9SDimitry Andric     ResultPhi->addIncoming(NewVResult, CondBlock);
744e8d8bef9SDimitry Andric     ResultPhi->addIncoming(VResult, PrevIfBlock);
745e8d8bef9SDimitry Andric     VResult = ResultPhi;
746e8d8bef9SDimitry Andric 
747e8d8bef9SDimitry Andric     // Add a PHI for the pointer if this isn't the last iteration.
748e8d8bef9SDimitry Andric     if ((Idx + 1) != VectorWidth) {
749e8d8bef9SDimitry Andric       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
750e8d8bef9SDimitry Andric       PtrPhi->addIncoming(NewPtr, CondBlock);
751e8d8bef9SDimitry Andric       PtrPhi->addIncoming(Ptr, PrevIfBlock);
752e8d8bef9SDimitry Andric       Ptr = PtrPhi;
753e8d8bef9SDimitry Andric     }
754e8d8bef9SDimitry Andric   }
755e8d8bef9SDimitry Andric 
756e8d8bef9SDimitry Andric   CI->replaceAllUsesWith(VResult);
757e8d8bef9SDimitry Andric   CI->eraseFromParent();
758e8d8bef9SDimitry Andric 
759e8d8bef9SDimitry Andric   ModifiedDT = true;
760e8d8bef9SDimitry Andric }
761e8d8bef9SDimitry Andric 
762fe6060f1SDimitry Andric static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
763fe6060f1SDimitry Andric                                          DomTreeUpdater *DTU,
764fe6060f1SDimitry Andric                                          bool &ModifiedDT) {
765e8d8bef9SDimitry Andric   Value *Src = CI->getArgOperand(0);
766e8d8bef9SDimitry Andric   Value *Ptr = CI->getArgOperand(1);
767e8d8bef9SDimitry Andric   Value *Mask = CI->getArgOperand(2);
768e8d8bef9SDimitry Andric 
769e8d8bef9SDimitry Andric   auto *VecType = cast<FixedVectorType>(Src->getType());
770e8d8bef9SDimitry Andric 
771e8d8bef9SDimitry Andric   IRBuilder<> Builder(CI->getContext());
772e8d8bef9SDimitry Andric   Instruction *InsertPt = CI;
773e8d8bef9SDimitry Andric   BasicBlock *IfBlock = CI->getParent();
774e8d8bef9SDimitry Andric 
775e8d8bef9SDimitry Andric   Builder.SetInsertPoint(InsertPt);
776e8d8bef9SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
777e8d8bef9SDimitry Andric 
778e8d8bef9SDimitry Andric   Type *EltTy = VecType->getElementType();
779e8d8bef9SDimitry Andric 
780e8d8bef9SDimitry Andric   unsigned VectorWidth = VecType->getNumElements();
781e8d8bef9SDimitry Andric 
782e8d8bef9SDimitry Andric   // Shorten the way if the mask is a vector of constants.
783e8d8bef9SDimitry Andric   if (isConstantIntVector(Mask)) {
784e8d8bef9SDimitry Andric     unsigned MemIndex = 0;
785e8d8bef9SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
786e8d8bef9SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
787e8d8bef9SDimitry Andric         continue;
788e8d8bef9SDimitry Andric       Value *OneElt =
789e8d8bef9SDimitry Andric           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
790e8d8bef9SDimitry Andric       Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
791e8d8bef9SDimitry Andric       Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
792e8d8bef9SDimitry Andric       ++MemIndex;
793e8d8bef9SDimitry Andric     }
794e8d8bef9SDimitry Andric     CI->eraseFromParent();
795e8d8bef9SDimitry Andric     return;
796e8d8bef9SDimitry Andric   }
797e8d8bef9SDimitry Andric 
798e8d8bef9SDimitry Andric   // If the mask is not v1i1, use scalar bit test operations. This generates
799e8d8bef9SDimitry Andric   // better results on X86 at least.
800e8d8bef9SDimitry Andric   Value *SclrMask;
801e8d8bef9SDimitry Andric   if (VectorWidth != 1) {
802e8d8bef9SDimitry Andric     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
803e8d8bef9SDimitry Andric     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
804e8d8bef9SDimitry Andric   }
805e8d8bef9SDimitry Andric 
806e8d8bef9SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
807e8d8bef9SDimitry Andric     // Fill the "else" block, created in the previous iteration
808e8d8bef9SDimitry Andric     //
809e8d8bef9SDimitry Andric     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
810e8d8bef9SDimitry Andric     //  br i1 %mask_1, label %cond.store, label %else
811e8d8bef9SDimitry Andric     //
812e8d8bef9SDimitry Andric     Value *Predicate;
813e8d8bef9SDimitry Andric     if (VectorWidth != 1) {
814fe6060f1SDimitry Andric       Value *Mask = Builder.getInt(APInt::getOneBitSet(
815fe6060f1SDimitry Andric           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
816e8d8bef9SDimitry Andric       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
817e8d8bef9SDimitry Andric                                        Builder.getIntN(VectorWidth, 0));
818e8d8bef9SDimitry Andric     } else {
819e8d8bef9SDimitry Andric       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
820e8d8bef9SDimitry Andric     }
821e8d8bef9SDimitry Andric 
822e8d8bef9SDimitry Andric     // Create "cond" block
823e8d8bef9SDimitry Andric     //
824e8d8bef9SDimitry Andric     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
825e8d8bef9SDimitry Andric     //  %EltAddr = getelementptr i32* %1, i32 0
826e8d8bef9SDimitry Andric     //  %store i32 %OneElt, i32* %EltAddr
827e8d8bef9SDimitry Andric     //
828fe6060f1SDimitry Andric     Instruction *ThenTerm =
829fe6060f1SDimitry Andric         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
830fe6060f1SDimitry Andric                                   /*BranchWeights=*/nullptr, DTU);
831e8d8bef9SDimitry Andric 
832fe6060f1SDimitry Andric     BasicBlock *CondBlock = ThenTerm->getParent();
833fe6060f1SDimitry Andric     CondBlock->setName("cond.store");
834fe6060f1SDimitry Andric 
835fe6060f1SDimitry Andric     Builder.SetInsertPoint(CondBlock->getTerminator());
836e8d8bef9SDimitry Andric     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
837e8d8bef9SDimitry Andric     Builder.CreateAlignedStore(OneElt, Ptr, Align(1));
838e8d8bef9SDimitry Andric 
839e8d8bef9SDimitry Andric     // Move the pointer if there are more blocks to come.
840e8d8bef9SDimitry Andric     Value *NewPtr;
841e8d8bef9SDimitry Andric     if ((Idx + 1) != VectorWidth)
842e8d8bef9SDimitry Andric       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
843e8d8bef9SDimitry Andric 
844e8d8bef9SDimitry Andric     // Create "else" block, fill it in the next iteration
845fe6060f1SDimitry Andric     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
846fe6060f1SDimitry Andric     NewIfBlock->setName("else");
847e8d8bef9SDimitry Andric     BasicBlock *PrevIfBlock = IfBlock;
848e8d8bef9SDimitry Andric     IfBlock = NewIfBlock;
849e8d8bef9SDimitry Andric 
850fe6060f1SDimitry Andric     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
851fe6060f1SDimitry Andric 
852e8d8bef9SDimitry Andric     // Add a PHI for the pointer if this isn't the last iteration.
853e8d8bef9SDimitry Andric     if ((Idx + 1) != VectorWidth) {
854e8d8bef9SDimitry Andric       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
855e8d8bef9SDimitry Andric       PtrPhi->addIncoming(NewPtr, CondBlock);
856e8d8bef9SDimitry Andric       PtrPhi->addIncoming(Ptr, PrevIfBlock);
857e8d8bef9SDimitry Andric       Ptr = PtrPhi;
858e8d8bef9SDimitry Andric     }
859e8d8bef9SDimitry Andric   }
860e8d8bef9SDimitry Andric   CI->eraseFromParent();
861e8d8bef9SDimitry Andric 
862e8d8bef9SDimitry Andric   ModifiedDT = true;
863e8d8bef9SDimitry Andric }
864e8d8bef9SDimitry Andric 
865fe6060f1SDimitry Andric static bool runImpl(Function &F, const TargetTransformInfo &TTI,
866fe6060f1SDimitry Andric                     DominatorTree *DT) {
867fe6060f1SDimitry Andric   Optional<DomTreeUpdater> DTU;
868fe6060f1SDimitry Andric   if (DT)
869fe6060f1SDimitry Andric     DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
870fe6060f1SDimitry Andric 
871e8d8bef9SDimitry Andric   bool EverMadeChange = false;
872e8d8bef9SDimitry Andric   bool MadeChange = true;
873e8d8bef9SDimitry Andric   auto &DL = F.getParent()->getDataLayout();
874e8d8bef9SDimitry Andric   while (MadeChange) {
875e8d8bef9SDimitry Andric     MadeChange = false;
876*349cc55cSDimitry Andric     for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
877e8d8bef9SDimitry Andric       bool ModifiedDTOnIteration = false;
878*349cc55cSDimitry Andric       MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
879fe6060f1SDimitry Andric                                   DTU.hasValue() ? DTU.getPointer() : nullptr);
880fe6060f1SDimitry Andric 
881e8d8bef9SDimitry Andric       // Restart BB iteration if the dominator tree of the Function was changed
882e8d8bef9SDimitry Andric       if (ModifiedDTOnIteration)
883e8d8bef9SDimitry Andric         break;
884e8d8bef9SDimitry Andric     }
885e8d8bef9SDimitry Andric 
886e8d8bef9SDimitry Andric     EverMadeChange |= MadeChange;
887e8d8bef9SDimitry Andric   }
888e8d8bef9SDimitry Andric   return EverMadeChange;
889e8d8bef9SDimitry Andric }
890e8d8bef9SDimitry Andric 
891e8d8bef9SDimitry Andric bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
892e8d8bef9SDimitry Andric   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
893fe6060f1SDimitry Andric   DominatorTree *DT = nullptr;
894fe6060f1SDimitry Andric   if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
895fe6060f1SDimitry Andric     DT = &DTWP->getDomTree();
896fe6060f1SDimitry Andric   return runImpl(F, TTI, DT);
897e8d8bef9SDimitry Andric }
898e8d8bef9SDimitry Andric 
899e8d8bef9SDimitry Andric PreservedAnalyses
900e8d8bef9SDimitry Andric ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
901e8d8bef9SDimitry Andric   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
902fe6060f1SDimitry Andric   auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
903fe6060f1SDimitry Andric   if (!runImpl(F, TTI, DT))
904e8d8bef9SDimitry Andric     return PreservedAnalyses::all();
905e8d8bef9SDimitry Andric   PreservedAnalyses PA;
906e8d8bef9SDimitry Andric   PA.preserve<TargetIRAnalysis>();
907fe6060f1SDimitry Andric   PA.preserve<DominatorTreeAnalysis>();
908e8d8bef9SDimitry Andric   return PA;
909e8d8bef9SDimitry Andric }
910e8d8bef9SDimitry Andric 
911e8d8bef9SDimitry Andric static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
912fe6060f1SDimitry Andric                           const TargetTransformInfo &TTI, const DataLayout &DL,
913fe6060f1SDimitry Andric                           DomTreeUpdater *DTU) {
914e8d8bef9SDimitry Andric   bool MadeChange = false;
915e8d8bef9SDimitry Andric 
916e8d8bef9SDimitry Andric   BasicBlock::iterator CurInstIterator = BB.begin();
917e8d8bef9SDimitry Andric   while (CurInstIterator != BB.end()) {
918e8d8bef9SDimitry Andric     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
919fe6060f1SDimitry Andric       MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
920e8d8bef9SDimitry Andric     if (ModifiedDT)
921e8d8bef9SDimitry Andric       return true;
922e8d8bef9SDimitry Andric   }
923e8d8bef9SDimitry Andric 
924e8d8bef9SDimitry Andric   return MadeChange;
925e8d8bef9SDimitry Andric }
926e8d8bef9SDimitry Andric 
927e8d8bef9SDimitry Andric static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
928e8d8bef9SDimitry Andric                              const TargetTransformInfo &TTI,
929fe6060f1SDimitry Andric                              const DataLayout &DL, DomTreeUpdater *DTU) {
930e8d8bef9SDimitry Andric   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
931e8d8bef9SDimitry Andric   if (II) {
932e8d8bef9SDimitry Andric     // The scalarization code below does not work for scalable vectors.
933e8d8bef9SDimitry Andric     if (isa<ScalableVectorType>(II->getType()) ||
934*349cc55cSDimitry Andric         any_of(II->args(),
935e8d8bef9SDimitry Andric                [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
936e8d8bef9SDimitry Andric       return false;
937e8d8bef9SDimitry Andric 
938e8d8bef9SDimitry Andric     switch (II->getIntrinsicID()) {
939e8d8bef9SDimitry Andric     default:
940e8d8bef9SDimitry Andric       break;
941e8d8bef9SDimitry Andric     case Intrinsic::masked_load:
942e8d8bef9SDimitry Andric       // Scalarize unsupported vector masked load
943e8d8bef9SDimitry Andric       if (TTI.isLegalMaskedLoad(
944e8d8bef9SDimitry Andric               CI->getType(),
945e8d8bef9SDimitry Andric               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
946e8d8bef9SDimitry Andric         return false;
947fe6060f1SDimitry Andric       scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
948e8d8bef9SDimitry Andric       return true;
949e8d8bef9SDimitry Andric     case Intrinsic::masked_store:
950e8d8bef9SDimitry Andric       if (TTI.isLegalMaskedStore(
951e8d8bef9SDimitry Andric               CI->getArgOperand(0)->getType(),
952e8d8bef9SDimitry Andric               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
953e8d8bef9SDimitry Andric         return false;
954fe6060f1SDimitry Andric       scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
955e8d8bef9SDimitry Andric       return true;
956e8d8bef9SDimitry Andric     case Intrinsic::masked_gather: {
957fe6060f1SDimitry Andric       MaybeAlign MA =
958fe6060f1SDimitry Andric           cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
959e8d8bef9SDimitry Andric       Type *LoadTy = CI->getType();
960fe6060f1SDimitry Andric       Align Alignment = DL.getValueOrABITypeAlignment(MA,
961fe6060f1SDimitry Andric                                                       LoadTy->getScalarType());
962e8d8bef9SDimitry Andric       if (TTI.isLegalMaskedGather(LoadTy, Alignment))
963e8d8bef9SDimitry Andric         return false;
964fe6060f1SDimitry Andric       scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
965e8d8bef9SDimitry Andric       return true;
966e8d8bef9SDimitry Andric     }
967e8d8bef9SDimitry Andric     case Intrinsic::masked_scatter: {
968fe6060f1SDimitry Andric       MaybeAlign MA =
969fe6060f1SDimitry Andric           cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
970e8d8bef9SDimitry Andric       Type *StoreTy = CI->getArgOperand(0)->getType();
971fe6060f1SDimitry Andric       Align Alignment = DL.getValueOrABITypeAlignment(MA,
972fe6060f1SDimitry Andric                                                       StoreTy->getScalarType());
973e8d8bef9SDimitry Andric       if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
974e8d8bef9SDimitry Andric         return false;
975fe6060f1SDimitry Andric       scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
976e8d8bef9SDimitry Andric       return true;
977e8d8bef9SDimitry Andric     }
978e8d8bef9SDimitry Andric     case Intrinsic::masked_expandload:
979e8d8bef9SDimitry Andric       if (TTI.isLegalMaskedExpandLoad(CI->getType()))
980e8d8bef9SDimitry Andric         return false;
981fe6060f1SDimitry Andric       scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
982e8d8bef9SDimitry Andric       return true;
983e8d8bef9SDimitry Andric     case Intrinsic::masked_compressstore:
984e8d8bef9SDimitry Andric       if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
985e8d8bef9SDimitry Andric         return false;
986fe6060f1SDimitry Andric       scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
987e8d8bef9SDimitry Andric       return true;
988e8d8bef9SDimitry Andric     }
989e8d8bef9SDimitry Andric   }
990e8d8bef9SDimitry Andric 
991e8d8bef9SDimitry Andric   return false;
992e8d8bef9SDimitry Andric }
993