xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (revision bdd1243df58e60e85101c09001d9812a789b6bc4)
1fcaf7f86SDimitry Andric //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2fcaf7f86SDimitry Andric //
3fcaf7f86SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fcaf7f86SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5fcaf7f86SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fcaf7f86SDimitry Andric //
7fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
8fcaf7f86SDimitry Andric //
9fcaf7f86SDimitry Andric // This pass modifies function signatures containing aggregate arguments
10fcaf7f86SDimitry Andric // and/or return value. Also it substitutes some llvm intrinsic calls by
11fcaf7f86SDimitry Andric // function calls, generating these functions as the translator does.
12fcaf7f86SDimitry Andric //
13fcaf7f86SDimitry Andric // NOTE: this pass is a module-level one due to the necessity to modify
14fcaf7f86SDimitry Andric // GVs/functions.
15fcaf7f86SDimitry Andric //
16fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
17fcaf7f86SDimitry Andric 
18fcaf7f86SDimitry Andric #include "SPIRV.h"
19fcaf7f86SDimitry Andric #include "SPIRVTargetMachine.h"
20fcaf7f86SDimitry Andric #include "SPIRVUtils.h"
21*bdd1243dSDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h"
22fcaf7f86SDimitry Andric #include "llvm/IR/IRBuilder.h"
23fcaf7f86SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
24fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
25fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
26fcaf7f86SDimitry Andric 
27fcaf7f86SDimitry Andric using namespace llvm;
28fcaf7f86SDimitry Andric 
29fcaf7f86SDimitry Andric namespace llvm {
30fcaf7f86SDimitry Andric void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
31fcaf7f86SDimitry Andric }
32fcaf7f86SDimitry Andric 
33fcaf7f86SDimitry Andric namespace {
34fcaf7f86SDimitry Andric 
35fcaf7f86SDimitry Andric class SPIRVPrepareFunctions : public ModulePass {
36fcaf7f86SDimitry Andric   Function *processFunctionSignature(Function *F);
37fcaf7f86SDimitry Andric 
38fcaf7f86SDimitry Andric public:
39fcaf7f86SDimitry Andric   static char ID;
40fcaf7f86SDimitry Andric   SPIRVPrepareFunctions() : ModulePass(ID) {
41fcaf7f86SDimitry Andric     initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
42fcaf7f86SDimitry Andric   }
43fcaf7f86SDimitry Andric 
44fcaf7f86SDimitry Andric   bool runOnModule(Module &M) override;
45fcaf7f86SDimitry Andric 
46fcaf7f86SDimitry Andric   StringRef getPassName() const override { return "SPIRV prepare functions"; }
47fcaf7f86SDimitry Andric 
48fcaf7f86SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
49fcaf7f86SDimitry Andric     ModulePass::getAnalysisUsage(AU);
50fcaf7f86SDimitry Andric   }
51fcaf7f86SDimitry Andric };
52fcaf7f86SDimitry Andric 
53fcaf7f86SDimitry Andric } // namespace
54fcaf7f86SDimitry Andric 
55fcaf7f86SDimitry Andric char SPIRVPrepareFunctions::ID = 0;
56fcaf7f86SDimitry Andric 
57fcaf7f86SDimitry Andric INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
58fcaf7f86SDimitry Andric                 "SPIRV prepare functions", false, false)
59fcaf7f86SDimitry Andric 
60fcaf7f86SDimitry Andric Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) {
61fcaf7f86SDimitry Andric   IRBuilder<> B(F->getContext());
62fcaf7f86SDimitry Andric 
63fcaf7f86SDimitry Andric   bool IsRetAggr = F->getReturnType()->isAggregateType();
64fcaf7f86SDimitry Andric   bool HasAggrArg =
65fcaf7f86SDimitry Andric       std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
66fcaf7f86SDimitry Andric         return Arg.getType()->isAggregateType();
67fcaf7f86SDimitry Andric       });
68fcaf7f86SDimitry Andric   bool DoClone = IsRetAggr || HasAggrArg;
69fcaf7f86SDimitry Andric   if (!DoClone)
70fcaf7f86SDimitry Andric     return F;
71fcaf7f86SDimitry Andric   SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
72fcaf7f86SDimitry Andric   Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
73fcaf7f86SDimitry Andric   if (IsRetAggr)
74fcaf7f86SDimitry Andric     ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
75fcaf7f86SDimitry Andric   SmallVector<Type *, 4> ArgTypes;
76fcaf7f86SDimitry Andric   for (const auto &Arg : F->args()) {
77fcaf7f86SDimitry Andric     if (Arg.getType()->isAggregateType()) {
78fcaf7f86SDimitry Andric       ArgTypes.push_back(B.getInt32Ty());
79fcaf7f86SDimitry Andric       ChangedTypes.push_back(
80fcaf7f86SDimitry Andric           std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
81fcaf7f86SDimitry Andric     } else
82fcaf7f86SDimitry Andric       ArgTypes.push_back(Arg.getType());
83fcaf7f86SDimitry Andric   }
84fcaf7f86SDimitry Andric   FunctionType *NewFTy =
85fcaf7f86SDimitry Andric       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
86fcaf7f86SDimitry Andric   Function *NewF =
87fcaf7f86SDimitry Andric       Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
88fcaf7f86SDimitry Andric 
89fcaf7f86SDimitry Andric   ValueToValueMapTy VMap;
90fcaf7f86SDimitry Andric   auto NewFArgIt = NewF->arg_begin();
91fcaf7f86SDimitry Andric   for (auto &Arg : F->args()) {
92fcaf7f86SDimitry Andric     StringRef ArgName = Arg.getName();
93fcaf7f86SDimitry Andric     NewFArgIt->setName(ArgName);
94fcaf7f86SDimitry Andric     VMap[&Arg] = &(*NewFArgIt++);
95fcaf7f86SDimitry Andric   }
96fcaf7f86SDimitry Andric   SmallVector<ReturnInst *, 8> Returns;
97fcaf7f86SDimitry Andric 
98fcaf7f86SDimitry Andric   CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
99fcaf7f86SDimitry Andric                     Returns);
100fcaf7f86SDimitry Andric   NewF->takeName(F);
101fcaf7f86SDimitry Andric 
102fcaf7f86SDimitry Andric   NamedMDNode *FuncMD =
103fcaf7f86SDimitry Andric       F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
104fcaf7f86SDimitry Andric   SmallVector<Metadata *, 2> MDArgs;
105fcaf7f86SDimitry Andric   MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
106fcaf7f86SDimitry Andric   for (auto &ChangedTyP : ChangedTypes)
107fcaf7f86SDimitry Andric     MDArgs.push_back(MDNode::get(
108fcaf7f86SDimitry Andric         B.getContext(),
109fcaf7f86SDimitry Andric         {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
110fcaf7f86SDimitry Andric          ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
111fcaf7f86SDimitry Andric   MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
112fcaf7f86SDimitry Andric   FuncMD->addOperand(ThisFuncMD);
113fcaf7f86SDimitry Andric 
114fcaf7f86SDimitry Andric   for (auto *U : make_early_inc_range(F->users())) {
115fcaf7f86SDimitry Andric     if (auto *CI = dyn_cast<CallInst>(U))
116fcaf7f86SDimitry Andric       CI->mutateFunctionType(NewF->getFunctionType());
117fcaf7f86SDimitry Andric     U->replaceUsesOfWith(F, NewF);
118fcaf7f86SDimitry Andric   }
119fcaf7f86SDimitry Andric   return NewF;
120fcaf7f86SDimitry Andric }
121fcaf7f86SDimitry Andric 
122fcaf7f86SDimitry Andric std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
123fcaf7f86SDimitry Andric   Function *IntrinsicFunc = II->getCalledFunction();
124fcaf7f86SDimitry Andric   assert(IntrinsicFunc && "Missing function");
125fcaf7f86SDimitry Andric   std::string FuncName = IntrinsicFunc->getName().str();
126fcaf7f86SDimitry Andric   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
127fcaf7f86SDimitry Andric   FuncName = "spirv." + FuncName;
128fcaf7f86SDimitry Andric   return FuncName;
129fcaf7f86SDimitry Andric }
130fcaf7f86SDimitry Andric 
131fcaf7f86SDimitry Andric static Function *getOrCreateFunction(Module *M, Type *RetTy,
132fcaf7f86SDimitry Andric                                      ArrayRef<Type *> ArgTypes,
133fcaf7f86SDimitry Andric                                      StringRef Name) {
134fcaf7f86SDimitry Andric   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
135fcaf7f86SDimitry Andric   Function *F = M->getFunction(Name);
136fcaf7f86SDimitry Andric   if (F && F->getFunctionType() == FT)
137fcaf7f86SDimitry Andric     return F;
138fcaf7f86SDimitry Andric   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
139fcaf7f86SDimitry Andric   if (F)
140fcaf7f86SDimitry Andric     NewF->setDSOLocal(F->isDSOLocal());
141fcaf7f86SDimitry Andric   NewF->setCallingConv(CallingConv::SPIR_FUNC);
142fcaf7f86SDimitry Andric   return NewF;
143fcaf7f86SDimitry Andric }
144fcaf7f86SDimitry Andric 
145*bdd1243dSDimitry Andric static void lowerIntrinsicToFunction(Module *M, IntrinsicInst *Intrinsic) {
146*bdd1243dSDimitry Andric   // For @llvm.memset.* intrinsic cases with constant value and length arguments
147*bdd1243dSDimitry Andric   // are emulated via "storing" a constant array to the destination. For other
148*bdd1243dSDimitry Andric   // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
149*bdd1243dSDimitry Andric   // intrinsic to a loop via expandMemSetAsLoop().
150*bdd1243dSDimitry Andric   if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
151*bdd1243dSDimitry Andric     if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
152*bdd1243dSDimitry Andric       return; // It is handled later using OpCopyMemorySized.
153*bdd1243dSDimitry Andric 
154*bdd1243dSDimitry Andric   std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
155*bdd1243dSDimitry Andric   if (Intrinsic->isVolatile())
156*bdd1243dSDimitry Andric     FuncName += ".volatile";
157*bdd1243dSDimitry Andric   // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
158*bdd1243dSDimitry Andric   Function *F = M->getFunction(FuncName);
159*bdd1243dSDimitry Andric   if (F) {
160*bdd1243dSDimitry Andric     Intrinsic->setCalledFunction(F);
161*bdd1243dSDimitry Andric     return;
162*bdd1243dSDimitry Andric   }
163*bdd1243dSDimitry Andric   // TODO copy arguments attributes: nocapture writeonly.
164*bdd1243dSDimitry Andric   FunctionCallee FC =
165*bdd1243dSDimitry Andric       M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
166*bdd1243dSDimitry Andric   auto IntrinsicID = Intrinsic->getIntrinsicID();
167*bdd1243dSDimitry Andric   Intrinsic->setCalledFunction(FC);
168*bdd1243dSDimitry Andric 
169*bdd1243dSDimitry Andric   F = dyn_cast<Function>(FC.getCallee());
170*bdd1243dSDimitry Andric   assert(F && "Callee must be a function");
171*bdd1243dSDimitry Andric 
172*bdd1243dSDimitry Andric   switch (IntrinsicID) {
173*bdd1243dSDimitry Andric   case Intrinsic::memset: {
174*bdd1243dSDimitry Andric     auto *MSI = static_cast<MemSetInst *>(Intrinsic);
175*bdd1243dSDimitry Andric     Argument *Dest = F->getArg(0);
176*bdd1243dSDimitry Andric     Argument *Val = F->getArg(1);
177*bdd1243dSDimitry Andric     Argument *Len = F->getArg(2);
178*bdd1243dSDimitry Andric     Argument *IsVolatile = F->getArg(3);
179*bdd1243dSDimitry Andric     Dest->setName("dest");
180*bdd1243dSDimitry Andric     Val->setName("val");
181*bdd1243dSDimitry Andric     Len->setName("len");
182*bdd1243dSDimitry Andric     IsVolatile->setName("isvolatile");
183*bdd1243dSDimitry Andric     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
184*bdd1243dSDimitry Andric     IRBuilder<> IRB(EntryBB);
185*bdd1243dSDimitry Andric     auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
186*bdd1243dSDimitry Andric                                     MSI->isVolatile());
187*bdd1243dSDimitry Andric     IRB.CreateRetVoid();
188*bdd1243dSDimitry Andric     expandMemSetAsLoop(cast<MemSetInst>(MemSet));
189*bdd1243dSDimitry Andric     MemSet->eraseFromParent();
190*bdd1243dSDimitry Andric     break;
191*bdd1243dSDimitry Andric   }
192*bdd1243dSDimitry Andric   case Intrinsic::bswap: {
193*bdd1243dSDimitry Andric     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
194*bdd1243dSDimitry Andric     IRBuilder<> IRB(EntryBB);
195*bdd1243dSDimitry Andric     auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
196*bdd1243dSDimitry Andric                                       F->getArg(0));
197*bdd1243dSDimitry Andric     IRB.CreateRet(BSwap);
198*bdd1243dSDimitry Andric     IntrinsicLowering IL(M->getDataLayout());
199*bdd1243dSDimitry Andric     IL.LowerIntrinsicCall(BSwap);
200*bdd1243dSDimitry Andric     break;
201*bdd1243dSDimitry Andric   }
202*bdd1243dSDimitry Andric   default:
203*bdd1243dSDimitry Andric     break;
204*bdd1243dSDimitry Andric   }
205*bdd1243dSDimitry Andric   return;
206*bdd1243dSDimitry Andric }
207*bdd1243dSDimitry Andric 
208fcaf7f86SDimitry Andric static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) {
209fcaf7f86SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
210fcaf7f86SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
211fcaf7f86SDimitry Andric   // function.
212fcaf7f86SDimitry Andric   // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
213fcaf7f86SDimitry Andric   FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
214fcaf7f86SDimitry Andric   Type *FSHRetTy = FSHFuncTy->getReturnType();
215fcaf7f86SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
216fcaf7f86SDimitry Andric   Function *FSHFunc =
217fcaf7f86SDimitry Andric       getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
218fcaf7f86SDimitry Andric 
219fcaf7f86SDimitry Andric   if (!FSHFunc->empty()) {
220fcaf7f86SDimitry Andric     FSHIntrinsic->setCalledFunction(FSHFunc);
221fcaf7f86SDimitry Andric     return;
222fcaf7f86SDimitry Andric   }
223fcaf7f86SDimitry Andric   BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
224fcaf7f86SDimitry Andric   IRBuilder<> IRB(RotateBB);
225fcaf7f86SDimitry Andric   Type *Ty = FSHFunc->getReturnType();
226fcaf7f86SDimitry Andric   // Build the actual funnel shift rotate logic.
227fcaf7f86SDimitry Andric   // In the comments, "int" is used interchangeably with "vector of int
228fcaf7f86SDimitry Andric   // elements".
229fcaf7f86SDimitry Andric   FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
230fcaf7f86SDimitry Andric   Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
231fcaf7f86SDimitry Andric   unsigned BitWidth = IntTy->getIntegerBitWidth();
232fcaf7f86SDimitry Andric   ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
233fcaf7f86SDimitry Andric   Value *BitWidthForInsts =
234fcaf7f86SDimitry Andric       VectorTy
235fcaf7f86SDimitry Andric           ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
236fcaf7f86SDimitry Andric           : BitWidthConstant;
237fcaf7f86SDimitry Andric   Value *RotateModVal =
238fcaf7f86SDimitry Andric       IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
239fcaf7f86SDimitry Andric   Value *FirstShift = nullptr, *SecShift = nullptr;
240fcaf7f86SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
241fcaf7f86SDimitry Andric     // Shift the less significant number right, the "rotate" number of bits
242fcaf7f86SDimitry Andric     // will be 0-filled on the left as a result of this regular shift.
243fcaf7f86SDimitry Andric     FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
244fcaf7f86SDimitry Andric   } else {
245fcaf7f86SDimitry Andric     // Shift the more significant number left, the "rotate" number of bits
246fcaf7f86SDimitry Andric     // will be 0-filled on the right as a result of this regular shift.
247fcaf7f86SDimitry Andric     FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
248fcaf7f86SDimitry Andric   }
249fcaf7f86SDimitry Andric   // We want the "rotate" number of the more significant int's LSBs (MSBs) to
250fcaf7f86SDimitry Andric   // occupy the leftmost (rightmost) "0 space" left by the previous operation.
251fcaf7f86SDimitry Andric   // Therefore, subtract the "rotate" number from the integer bitsize...
252fcaf7f86SDimitry Andric   Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
253fcaf7f86SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
254fcaf7f86SDimitry Andric     // ...and left-shift the more significant int by this number, zero-filling
255fcaf7f86SDimitry Andric     // the LSBs.
256fcaf7f86SDimitry Andric     SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
257fcaf7f86SDimitry Andric   } else {
258fcaf7f86SDimitry Andric     // ...and right-shift the less significant int by this number, zero-filling
259fcaf7f86SDimitry Andric     // the MSBs.
260fcaf7f86SDimitry Andric     SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
261fcaf7f86SDimitry Andric   }
262fcaf7f86SDimitry Andric   // A simple binary addition of the shifted ints yields the final result.
263fcaf7f86SDimitry Andric   IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
264fcaf7f86SDimitry Andric 
265fcaf7f86SDimitry Andric   FSHIntrinsic->setCalledFunction(FSHFunc);
266fcaf7f86SDimitry Andric }
267fcaf7f86SDimitry Andric 
268fcaf7f86SDimitry Andric static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) {
269fcaf7f86SDimitry Andric   // The function body is already created.
270fcaf7f86SDimitry Andric   if (!UMulFunc->empty())
271fcaf7f86SDimitry Andric     return;
272fcaf7f86SDimitry Andric 
273fcaf7f86SDimitry Andric   BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc);
274fcaf7f86SDimitry Andric   IRBuilder<> IRB(EntryBB);
275fcaf7f86SDimitry Andric   // Build the actual unsigned multiplication logic with the overflow
276fcaf7f86SDimitry Andric   // indication. Do unsigned multiplication Mul = A * B. Then check
277fcaf7f86SDimitry Andric   // if unsigned division Div = Mul / A is not equal to B. If so,
278fcaf7f86SDimitry Andric   // then overflow has happened.
279fcaf7f86SDimitry Andric   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
280fcaf7f86SDimitry Andric   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
281fcaf7f86SDimitry Andric   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
282fcaf7f86SDimitry Andric 
283fcaf7f86SDimitry Andric   // umul.with.overflow intrinsic return a structure, where the first element
284fcaf7f86SDimitry Andric   // is the multiplication result, and the second is an overflow bit.
285fcaf7f86SDimitry Andric   Type *StructTy = UMulFunc->getReturnType();
286*bdd1243dSDimitry Andric   Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
287fcaf7f86SDimitry Andric   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
288fcaf7f86SDimitry Andric   IRB.CreateRet(Res);
289fcaf7f86SDimitry Andric }
290fcaf7f86SDimitry Andric 
291fcaf7f86SDimitry Andric static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) {
292fcaf7f86SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
293fcaf7f86SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
294fcaf7f86SDimitry Andric   // function.
295fcaf7f86SDimitry Andric   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
296fcaf7f86SDimitry Andric   Type *FSHLRetTy = UMulFuncTy->getReturnType();
297fcaf7f86SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
298fcaf7f86SDimitry Andric   Function *UMulFunc =
299fcaf7f86SDimitry Andric       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
300fcaf7f86SDimitry Andric   buildUMulWithOverflowFunc(M, UMulFunc);
301fcaf7f86SDimitry Andric   UMulIntrinsic->setCalledFunction(UMulFunc);
302fcaf7f86SDimitry Andric }
303fcaf7f86SDimitry Andric 
304fcaf7f86SDimitry Andric static void substituteIntrinsicCalls(Module *M, Function *F) {
305fcaf7f86SDimitry Andric   for (BasicBlock &BB : *F) {
306fcaf7f86SDimitry Andric     for (Instruction &I : BB) {
307fcaf7f86SDimitry Andric       auto Call = dyn_cast<CallInst>(&I);
308fcaf7f86SDimitry Andric       if (!Call)
309fcaf7f86SDimitry Andric         continue;
310fcaf7f86SDimitry Andric       Call->setTailCall(false);
311fcaf7f86SDimitry Andric       Function *CF = Call->getCalledFunction();
312fcaf7f86SDimitry Andric       if (!CF || !CF->isIntrinsic())
313fcaf7f86SDimitry Andric         continue;
314fcaf7f86SDimitry Andric       auto *II = cast<IntrinsicInst>(Call);
315*bdd1243dSDimitry Andric       if (II->getIntrinsicID() == Intrinsic::memset ||
316*bdd1243dSDimitry Andric           II->getIntrinsicID() == Intrinsic::bswap)
317*bdd1243dSDimitry Andric         lowerIntrinsicToFunction(M, II);
318*bdd1243dSDimitry Andric       else if (II->getIntrinsicID() == Intrinsic::fshl ||
319fcaf7f86SDimitry Andric                II->getIntrinsicID() == Intrinsic::fshr)
320fcaf7f86SDimitry Andric         lowerFunnelShifts(M, II);
321fcaf7f86SDimitry Andric       else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow)
322fcaf7f86SDimitry Andric         lowerUMulWithOverflow(M, II);
323fcaf7f86SDimitry Andric     }
324fcaf7f86SDimitry Andric   }
325fcaf7f86SDimitry Andric }
326fcaf7f86SDimitry Andric 
327fcaf7f86SDimitry Andric bool SPIRVPrepareFunctions::runOnModule(Module &M) {
328fcaf7f86SDimitry Andric   for (Function &F : M)
329fcaf7f86SDimitry Andric     substituteIntrinsicCalls(&M, &F);
330fcaf7f86SDimitry Andric 
331fcaf7f86SDimitry Andric   std::vector<Function *> FuncsWorklist;
332fcaf7f86SDimitry Andric   bool Changed = false;
333fcaf7f86SDimitry Andric   for (auto &F : M)
334fcaf7f86SDimitry Andric     FuncsWorklist.push_back(&F);
335fcaf7f86SDimitry Andric 
336fcaf7f86SDimitry Andric   for (auto *Func : FuncsWorklist) {
337fcaf7f86SDimitry Andric     Function *F = processFunctionSignature(Func);
338fcaf7f86SDimitry Andric 
339fcaf7f86SDimitry Andric     bool CreatedNewF = F != Func;
340fcaf7f86SDimitry Andric 
341fcaf7f86SDimitry Andric     if (Func->isDeclaration()) {
342fcaf7f86SDimitry Andric       Changed |= CreatedNewF;
343fcaf7f86SDimitry Andric       continue;
344fcaf7f86SDimitry Andric     }
345fcaf7f86SDimitry Andric 
346fcaf7f86SDimitry Andric     if (CreatedNewF)
347fcaf7f86SDimitry Andric       Func->eraseFromParent();
348fcaf7f86SDimitry Andric   }
349fcaf7f86SDimitry Andric 
350fcaf7f86SDimitry Andric   return Changed;
351fcaf7f86SDimitry Andric }
352fcaf7f86SDimitry Andric 
353fcaf7f86SDimitry Andric ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
354fcaf7f86SDimitry Andric   return new SPIRVPrepareFunctions();
355fcaf7f86SDimitry Andric }
356