xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (revision 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e)
1fcaf7f86SDimitry Andric //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2fcaf7f86SDimitry Andric //
3fcaf7f86SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fcaf7f86SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5fcaf7f86SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fcaf7f86SDimitry Andric //
7fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
8fcaf7f86SDimitry Andric //
9fcaf7f86SDimitry Andric // This pass modifies function signatures containing aggregate arguments
10*06c3fb27SDimitry Andric // and/or return value before IRTranslator. Information about the original
11*06c3fb27SDimitry Andric // signatures is stored in metadata. It is used during call lowering to
12*06c3fb27SDimitry Andric // restore correct SPIR-V types of function arguments and return values.
13*06c3fb27SDimitry Andric // This pass also substitutes some llvm intrinsic calls with calls to newly
14*06c3fb27SDimitry Andric // generated functions (as the Khronos LLVM/SPIR-V Translator does).
15fcaf7f86SDimitry Andric //
16fcaf7f86SDimitry Andric // NOTE: this pass is a module-level one due to the necessity to modify
17fcaf7f86SDimitry Andric // GVs/functions.
18fcaf7f86SDimitry Andric //
19fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
20fcaf7f86SDimitry Andric 
21fcaf7f86SDimitry Andric #include "SPIRV.h"
22fcaf7f86SDimitry Andric #include "SPIRVTargetMachine.h"
23fcaf7f86SDimitry Andric #include "SPIRVUtils.h"
24bdd1243dSDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h"
25fcaf7f86SDimitry Andric #include "llvm/IR/IRBuilder.h"
26fcaf7f86SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
27fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
28fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
29fcaf7f86SDimitry Andric 
30fcaf7f86SDimitry Andric using namespace llvm;
31fcaf7f86SDimitry Andric 
32fcaf7f86SDimitry Andric namespace llvm {
33fcaf7f86SDimitry Andric void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
34fcaf7f86SDimitry Andric }
35fcaf7f86SDimitry Andric 
36fcaf7f86SDimitry Andric namespace {
37fcaf7f86SDimitry Andric 
38fcaf7f86SDimitry Andric class SPIRVPrepareFunctions : public ModulePass {
39*06c3fb27SDimitry Andric   bool substituteIntrinsicCalls(Function *F);
40*06c3fb27SDimitry Andric   Function *removeAggregateTypesFromSignature(Function *F);
41fcaf7f86SDimitry Andric 
42fcaf7f86SDimitry Andric public:
43fcaf7f86SDimitry Andric   static char ID;
44fcaf7f86SDimitry Andric   SPIRVPrepareFunctions() : ModulePass(ID) {
45fcaf7f86SDimitry Andric     initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
46fcaf7f86SDimitry Andric   }
47fcaf7f86SDimitry Andric 
48fcaf7f86SDimitry Andric   bool runOnModule(Module &M) override;
49fcaf7f86SDimitry Andric 
50fcaf7f86SDimitry Andric   StringRef getPassName() const override { return "SPIRV prepare functions"; }
51fcaf7f86SDimitry Andric 
52fcaf7f86SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
53fcaf7f86SDimitry Andric     ModulePass::getAnalysisUsage(AU);
54fcaf7f86SDimitry Andric   }
55fcaf7f86SDimitry Andric };
56fcaf7f86SDimitry Andric 
57fcaf7f86SDimitry Andric } // namespace
58fcaf7f86SDimitry Andric 
59fcaf7f86SDimitry Andric char SPIRVPrepareFunctions::ID = 0;
60fcaf7f86SDimitry Andric 
61fcaf7f86SDimitry Andric INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
62fcaf7f86SDimitry Andric                 "SPIRV prepare functions", false, false)
63fcaf7f86SDimitry Andric 
64*06c3fb27SDimitry Andric std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
65*06c3fb27SDimitry Andric   Function *IntrinsicFunc = II->getCalledFunction();
66*06c3fb27SDimitry Andric   assert(IntrinsicFunc && "Missing function");
67*06c3fb27SDimitry Andric   std::string FuncName = IntrinsicFunc->getName().str();
68*06c3fb27SDimitry Andric   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
69*06c3fb27SDimitry Andric   FuncName = "spirv." + FuncName;
70*06c3fb27SDimitry Andric   return FuncName;
71*06c3fb27SDimitry Andric }
72*06c3fb27SDimitry Andric 
73*06c3fb27SDimitry Andric static Function *getOrCreateFunction(Module *M, Type *RetTy,
74*06c3fb27SDimitry Andric                                      ArrayRef<Type *> ArgTypes,
75*06c3fb27SDimitry Andric                                      StringRef Name) {
76*06c3fb27SDimitry Andric   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
77*06c3fb27SDimitry Andric   Function *F = M->getFunction(Name);
78*06c3fb27SDimitry Andric   if (F && F->getFunctionType() == FT)
79*06c3fb27SDimitry Andric     return F;
80*06c3fb27SDimitry Andric   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
81*06c3fb27SDimitry Andric   if (F)
82*06c3fb27SDimitry Andric     NewF->setDSOLocal(F->isDSOLocal());
83*06c3fb27SDimitry Andric   NewF->setCallingConv(CallingConv::SPIR_FUNC);
84*06c3fb27SDimitry Andric   return NewF;
85*06c3fb27SDimitry Andric }
86*06c3fb27SDimitry Andric 
87*06c3fb27SDimitry Andric static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
88*06c3fb27SDimitry Andric   // For @llvm.memset.* intrinsic cases with constant value and length arguments
89*06c3fb27SDimitry Andric   // are emulated via "storing" a constant array to the destination. For other
90*06c3fb27SDimitry Andric   // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
91*06c3fb27SDimitry Andric   // intrinsic to a loop via expandMemSetAsLoop().
92*06c3fb27SDimitry Andric   if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
93*06c3fb27SDimitry Andric     if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
94*06c3fb27SDimitry Andric       return false; // It is handled later using OpCopyMemorySized.
95*06c3fb27SDimitry Andric 
96*06c3fb27SDimitry Andric   Module *M = Intrinsic->getModule();
97*06c3fb27SDimitry Andric   std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
98*06c3fb27SDimitry Andric   if (Intrinsic->isVolatile())
99*06c3fb27SDimitry Andric     FuncName += ".volatile";
100*06c3fb27SDimitry Andric   // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
101*06c3fb27SDimitry Andric   Function *F = M->getFunction(FuncName);
102*06c3fb27SDimitry Andric   if (F) {
103*06c3fb27SDimitry Andric     Intrinsic->setCalledFunction(F);
104*06c3fb27SDimitry Andric     return true;
105*06c3fb27SDimitry Andric   }
106*06c3fb27SDimitry Andric   // TODO copy arguments attributes: nocapture writeonly.
107*06c3fb27SDimitry Andric   FunctionCallee FC =
108*06c3fb27SDimitry Andric       M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
109*06c3fb27SDimitry Andric   auto IntrinsicID = Intrinsic->getIntrinsicID();
110*06c3fb27SDimitry Andric   Intrinsic->setCalledFunction(FC);
111*06c3fb27SDimitry Andric 
112*06c3fb27SDimitry Andric   F = dyn_cast<Function>(FC.getCallee());
113*06c3fb27SDimitry Andric   assert(F && "Callee must be a function");
114*06c3fb27SDimitry Andric 
115*06c3fb27SDimitry Andric   switch (IntrinsicID) {
116*06c3fb27SDimitry Andric   case Intrinsic::memset: {
117*06c3fb27SDimitry Andric     auto *MSI = static_cast<MemSetInst *>(Intrinsic);
118*06c3fb27SDimitry Andric     Argument *Dest = F->getArg(0);
119*06c3fb27SDimitry Andric     Argument *Val = F->getArg(1);
120*06c3fb27SDimitry Andric     Argument *Len = F->getArg(2);
121*06c3fb27SDimitry Andric     Argument *IsVolatile = F->getArg(3);
122*06c3fb27SDimitry Andric     Dest->setName("dest");
123*06c3fb27SDimitry Andric     Val->setName("val");
124*06c3fb27SDimitry Andric     Len->setName("len");
125*06c3fb27SDimitry Andric     IsVolatile->setName("isvolatile");
126*06c3fb27SDimitry Andric     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
127*06c3fb27SDimitry Andric     IRBuilder<> IRB(EntryBB);
128*06c3fb27SDimitry Andric     auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
129*06c3fb27SDimitry Andric                                     MSI->isVolatile());
130*06c3fb27SDimitry Andric     IRB.CreateRetVoid();
131*06c3fb27SDimitry Andric     expandMemSetAsLoop(cast<MemSetInst>(MemSet));
132*06c3fb27SDimitry Andric     MemSet->eraseFromParent();
133*06c3fb27SDimitry Andric     break;
134*06c3fb27SDimitry Andric   }
135*06c3fb27SDimitry Andric   case Intrinsic::bswap: {
136*06c3fb27SDimitry Andric     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
137*06c3fb27SDimitry Andric     IRBuilder<> IRB(EntryBB);
138*06c3fb27SDimitry Andric     auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
139*06c3fb27SDimitry Andric                                       F->getArg(0));
140*06c3fb27SDimitry Andric     IRB.CreateRet(BSwap);
141*06c3fb27SDimitry Andric     IntrinsicLowering IL(M->getDataLayout());
142*06c3fb27SDimitry Andric     IL.LowerIntrinsicCall(BSwap);
143*06c3fb27SDimitry Andric     break;
144*06c3fb27SDimitry Andric   }
145*06c3fb27SDimitry Andric   default:
146*06c3fb27SDimitry Andric     break;
147*06c3fb27SDimitry Andric   }
148*06c3fb27SDimitry Andric   return true;
149*06c3fb27SDimitry Andric }
150*06c3fb27SDimitry Andric 
151*06c3fb27SDimitry Andric static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
152*06c3fb27SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
153*06c3fb27SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
154*06c3fb27SDimitry Andric   // function.
155*06c3fb27SDimitry Andric   // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
156*06c3fb27SDimitry Andric   Module *M = FSHIntrinsic->getModule();
157*06c3fb27SDimitry Andric   FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
158*06c3fb27SDimitry Andric   Type *FSHRetTy = FSHFuncTy->getReturnType();
159*06c3fb27SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
160*06c3fb27SDimitry Andric   Function *FSHFunc =
161*06c3fb27SDimitry Andric       getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
162*06c3fb27SDimitry Andric 
163*06c3fb27SDimitry Andric   if (!FSHFunc->empty()) {
164*06c3fb27SDimitry Andric     FSHIntrinsic->setCalledFunction(FSHFunc);
165*06c3fb27SDimitry Andric     return;
166*06c3fb27SDimitry Andric   }
167*06c3fb27SDimitry Andric   BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
168*06c3fb27SDimitry Andric   IRBuilder<> IRB(RotateBB);
169*06c3fb27SDimitry Andric   Type *Ty = FSHFunc->getReturnType();
170*06c3fb27SDimitry Andric   // Build the actual funnel shift rotate logic.
171*06c3fb27SDimitry Andric   // In the comments, "int" is used interchangeably with "vector of int
172*06c3fb27SDimitry Andric   // elements".
173*06c3fb27SDimitry Andric   FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
174*06c3fb27SDimitry Andric   Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
175*06c3fb27SDimitry Andric   unsigned BitWidth = IntTy->getIntegerBitWidth();
176*06c3fb27SDimitry Andric   ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
177*06c3fb27SDimitry Andric   Value *BitWidthForInsts =
178*06c3fb27SDimitry Andric       VectorTy
179*06c3fb27SDimitry Andric           ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
180*06c3fb27SDimitry Andric           : BitWidthConstant;
181*06c3fb27SDimitry Andric   Value *RotateModVal =
182*06c3fb27SDimitry Andric       IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
183*06c3fb27SDimitry Andric   Value *FirstShift = nullptr, *SecShift = nullptr;
184*06c3fb27SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
185*06c3fb27SDimitry Andric     // Shift the less significant number right, the "rotate" number of bits
186*06c3fb27SDimitry Andric     // will be 0-filled on the left as a result of this regular shift.
187*06c3fb27SDimitry Andric     FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
188*06c3fb27SDimitry Andric   } else {
189*06c3fb27SDimitry Andric     // Shift the more significant number left, the "rotate" number of bits
190*06c3fb27SDimitry Andric     // will be 0-filled on the right as a result of this regular shift.
191*06c3fb27SDimitry Andric     FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
192*06c3fb27SDimitry Andric   }
193*06c3fb27SDimitry Andric   // We want the "rotate" number of the more significant int's LSBs (MSBs) to
194*06c3fb27SDimitry Andric   // occupy the leftmost (rightmost) "0 space" left by the previous operation.
195*06c3fb27SDimitry Andric   // Therefore, subtract the "rotate" number from the integer bitsize...
196*06c3fb27SDimitry Andric   Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
197*06c3fb27SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
198*06c3fb27SDimitry Andric     // ...and left-shift the more significant int by this number, zero-filling
199*06c3fb27SDimitry Andric     // the LSBs.
200*06c3fb27SDimitry Andric     SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
201*06c3fb27SDimitry Andric   } else {
202*06c3fb27SDimitry Andric     // ...and right-shift the less significant int by this number, zero-filling
203*06c3fb27SDimitry Andric     // the MSBs.
204*06c3fb27SDimitry Andric     SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
205*06c3fb27SDimitry Andric   }
206*06c3fb27SDimitry Andric   // A simple binary addition of the shifted ints yields the final result.
207*06c3fb27SDimitry Andric   IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
208*06c3fb27SDimitry Andric 
209*06c3fb27SDimitry Andric   FSHIntrinsic->setCalledFunction(FSHFunc);
210*06c3fb27SDimitry Andric }
211*06c3fb27SDimitry Andric 
212*06c3fb27SDimitry Andric static void buildUMulWithOverflowFunc(Function *UMulFunc) {
213*06c3fb27SDimitry Andric   // The function body is already created.
214*06c3fb27SDimitry Andric   if (!UMulFunc->empty())
215*06c3fb27SDimitry Andric     return;
216*06c3fb27SDimitry Andric 
217*06c3fb27SDimitry Andric   BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
218*06c3fb27SDimitry Andric                                            "entry", UMulFunc);
219*06c3fb27SDimitry Andric   IRBuilder<> IRB(EntryBB);
220*06c3fb27SDimitry Andric   // Build the actual unsigned multiplication logic with the overflow
221*06c3fb27SDimitry Andric   // indication. Do unsigned multiplication Mul = A * B. Then check
222*06c3fb27SDimitry Andric   // if unsigned division Div = Mul / A is not equal to B. If so,
223*06c3fb27SDimitry Andric   // then overflow has happened.
224*06c3fb27SDimitry Andric   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
225*06c3fb27SDimitry Andric   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
226*06c3fb27SDimitry Andric   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
227*06c3fb27SDimitry Andric 
228*06c3fb27SDimitry Andric   // umul.with.overflow intrinsic return a structure, where the first element
229*06c3fb27SDimitry Andric   // is the multiplication result, and the second is an overflow bit.
230*06c3fb27SDimitry Andric   Type *StructTy = UMulFunc->getReturnType();
231*06c3fb27SDimitry Andric   Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
232*06c3fb27SDimitry Andric   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
233*06c3fb27SDimitry Andric   IRB.CreateRet(Res);
234*06c3fb27SDimitry Andric }
235*06c3fb27SDimitry Andric 
236*06c3fb27SDimitry Andric static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
237*06c3fb27SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
238*06c3fb27SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
239*06c3fb27SDimitry Andric   // function.
240*06c3fb27SDimitry Andric   Module *M = UMulIntrinsic->getModule();
241*06c3fb27SDimitry Andric   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
242*06c3fb27SDimitry Andric   Type *FSHLRetTy = UMulFuncTy->getReturnType();
243*06c3fb27SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
244*06c3fb27SDimitry Andric   Function *UMulFunc =
245*06c3fb27SDimitry Andric       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
246*06c3fb27SDimitry Andric   buildUMulWithOverflowFunc(UMulFunc);
247*06c3fb27SDimitry Andric   UMulIntrinsic->setCalledFunction(UMulFunc);
248*06c3fb27SDimitry Andric }
249*06c3fb27SDimitry Andric 
250*06c3fb27SDimitry Andric // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
251*06c3fb27SDimitry Andric // or calls to proper generated functions. Returns True if F was modified.
252*06c3fb27SDimitry Andric bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
253*06c3fb27SDimitry Andric   bool Changed = false;
254*06c3fb27SDimitry Andric   for (BasicBlock &BB : *F) {
255*06c3fb27SDimitry Andric     for (Instruction &I : BB) {
256*06c3fb27SDimitry Andric       auto Call = dyn_cast<CallInst>(&I);
257*06c3fb27SDimitry Andric       if (!Call)
258*06c3fb27SDimitry Andric         continue;
259*06c3fb27SDimitry Andric       Function *CF = Call->getCalledFunction();
260*06c3fb27SDimitry Andric       if (!CF || !CF->isIntrinsic())
261*06c3fb27SDimitry Andric         continue;
262*06c3fb27SDimitry Andric       auto *II = cast<IntrinsicInst>(Call);
263*06c3fb27SDimitry Andric       if (II->getIntrinsicID() == Intrinsic::memset ||
264*06c3fb27SDimitry Andric           II->getIntrinsicID() == Intrinsic::bswap)
265*06c3fb27SDimitry Andric         Changed |= lowerIntrinsicToFunction(II);
266*06c3fb27SDimitry Andric       else if (II->getIntrinsicID() == Intrinsic::fshl ||
267*06c3fb27SDimitry Andric                II->getIntrinsicID() == Intrinsic::fshr) {
268*06c3fb27SDimitry Andric         lowerFunnelShifts(II);
269*06c3fb27SDimitry Andric         Changed = true;
270*06c3fb27SDimitry Andric       } else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) {
271*06c3fb27SDimitry Andric         lowerUMulWithOverflow(II);
272*06c3fb27SDimitry Andric         Changed = true;
273*06c3fb27SDimitry Andric       }
274*06c3fb27SDimitry Andric     }
275*06c3fb27SDimitry Andric   }
276*06c3fb27SDimitry Andric   return Changed;
277*06c3fb27SDimitry Andric }
278*06c3fb27SDimitry Andric 
279*06c3fb27SDimitry Andric // Returns F if aggregate argument/return types are not present or cloned F
280*06c3fb27SDimitry Andric // function with the types replaced by i32 types. The change in types is
281*06c3fb27SDimitry Andric // noted in 'spv.cloned_funcs' metadata for later restoration.
282*06c3fb27SDimitry Andric Function *
283*06c3fb27SDimitry Andric SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
284fcaf7f86SDimitry Andric   IRBuilder<> B(F->getContext());
285fcaf7f86SDimitry Andric 
286fcaf7f86SDimitry Andric   bool IsRetAggr = F->getReturnType()->isAggregateType();
287fcaf7f86SDimitry Andric   bool HasAggrArg =
288fcaf7f86SDimitry Andric       std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
289fcaf7f86SDimitry Andric         return Arg.getType()->isAggregateType();
290fcaf7f86SDimitry Andric       });
291fcaf7f86SDimitry Andric   bool DoClone = IsRetAggr || HasAggrArg;
292fcaf7f86SDimitry Andric   if (!DoClone)
293fcaf7f86SDimitry Andric     return F;
294fcaf7f86SDimitry Andric   SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
295fcaf7f86SDimitry Andric   Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
296fcaf7f86SDimitry Andric   if (IsRetAggr)
297fcaf7f86SDimitry Andric     ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
298fcaf7f86SDimitry Andric   SmallVector<Type *, 4> ArgTypes;
299fcaf7f86SDimitry Andric   for (const auto &Arg : F->args()) {
300fcaf7f86SDimitry Andric     if (Arg.getType()->isAggregateType()) {
301fcaf7f86SDimitry Andric       ArgTypes.push_back(B.getInt32Ty());
302fcaf7f86SDimitry Andric       ChangedTypes.push_back(
303fcaf7f86SDimitry Andric           std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
304fcaf7f86SDimitry Andric     } else
305fcaf7f86SDimitry Andric       ArgTypes.push_back(Arg.getType());
306fcaf7f86SDimitry Andric   }
307fcaf7f86SDimitry Andric   FunctionType *NewFTy =
308fcaf7f86SDimitry Andric       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
309fcaf7f86SDimitry Andric   Function *NewF =
310fcaf7f86SDimitry Andric       Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
311fcaf7f86SDimitry Andric 
312fcaf7f86SDimitry Andric   ValueToValueMapTy VMap;
313fcaf7f86SDimitry Andric   auto NewFArgIt = NewF->arg_begin();
314fcaf7f86SDimitry Andric   for (auto &Arg : F->args()) {
315fcaf7f86SDimitry Andric     StringRef ArgName = Arg.getName();
316fcaf7f86SDimitry Andric     NewFArgIt->setName(ArgName);
317fcaf7f86SDimitry Andric     VMap[&Arg] = &(*NewFArgIt++);
318fcaf7f86SDimitry Andric   }
319fcaf7f86SDimitry Andric   SmallVector<ReturnInst *, 8> Returns;
320fcaf7f86SDimitry Andric 
321fcaf7f86SDimitry Andric   CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
322fcaf7f86SDimitry Andric                     Returns);
323fcaf7f86SDimitry Andric   NewF->takeName(F);
324fcaf7f86SDimitry Andric 
325fcaf7f86SDimitry Andric   NamedMDNode *FuncMD =
326fcaf7f86SDimitry Andric       F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
327fcaf7f86SDimitry Andric   SmallVector<Metadata *, 2> MDArgs;
328fcaf7f86SDimitry Andric   MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
329fcaf7f86SDimitry Andric   for (auto &ChangedTyP : ChangedTypes)
330fcaf7f86SDimitry Andric     MDArgs.push_back(MDNode::get(
331fcaf7f86SDimitry Andric         B.getContext(),
332fcaf7f86SDimitry Andric         {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
333fcaf7f86SDimitry Andric          ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
334fcaf7f86SDimitry Andric   MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
335fcaf7f86SDimitry Andric   FuncMD->addOperand(ThisFuncMD);
336fcaf7f86SDimitry Andric 
337fcaf7f86SDimitry Andric   for (auto *U : make_early_inc_range(F->users())) {
338fcaf7f86SDimitry Andric     if (auto *CI = dyn_cast<CallInst>(U))
339fcaf7f86SDimitry Andric       CI->mutateFunctionType(NewF->getFunctionType());
340fcaf7f86SDimitry Andric     U->replaceUsesOfWith(F, NewF);
341fcaf7f86SDimitry Andric   }
342fcaf7f86SDimitry Andric   return NewF;
343fcaf7f86SDimitry Andric }
344fcaf7f86SDimitry Andric 
345fcaf7f86SDimitry Andric bool SPIRVPrepareFunctions::runOnModule(Module &M) {
346*06c3fb27SDimitry Andric   bool Changed = false;
347fcaf7f86SDimitry Andric   for (Function &F : M)
348*06c3fb27SDimitry Andric     Changed |= substituteIntrinsicCalls(&F);
349fcaf7f86SDimitry Andric 
350fcaf7f86SDimitry Andric   std::vector<Function *> FuncsWorklist;
351fcaf7f86SDimitry Andric   for (auto &F : M)
352fcaf7f86SDimitry Andric     FuncsWorklist.push_back(&F);
353fcaf7f86SDimitry Andric 
354*06c3fb27SDimitry Andric   for (auto *F : FuncsWorklist) {
355*06c3fb27SDimitry Andric     Function *NewF = removeAggregateTypesFromSignature(F);
356fcaf7f86SDimitry Andric 
357*06c3fb27SDimitry Andric     if (NewF != F) {
358*06c3fb27SDimitry Andric       F->eraseFromParent();
359*06c3fb27SDimitry Andric       Changed = true;
360fcaf7f86SDimitry Andric     }
361fcaf7f86SDimitry Andric   }
362fcaf7f86SDimitry Andric   return Changed;
363fcaf7f86SDimitry Andric }
364fcaf7f86SDimitry Andric 
365fcaf7f86SDimitry Andric ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
366fcaf7f86SDimitry Andric   return new SPIRVPrepareFunctions();
367fcaf7f86SDimitry Andric }
368