xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (revision fcaf7f8644a9988098ac6be2165bce3ea4786e91)
1*fcaf7f86SDimitry Andric //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2*fcaf7f86SDimitry Andric //
3*fcaf7f86SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*fcaf7f86SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*fcaf7f86SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*fcaf7f86SDimitry Andric //
7*fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
8*fcaf7f86SDimitry Andric //
9*fcaf7f86SDimitry Andric // This pass modifies function signatures containing aggregate arguments
10*fcaf7f86SDimitry Andric // and/or return value. Also it substitutes some llvm intrinsic calls by
11*fcaf7f86SDimitry Andric // function calls, generating these functions as the translator does.
12*fcaf7f86SDimitry Andric //
13*fcaf7f86SDimitry Andric // NOTE: this pass is a module-level one due to the necessity to modify
14*fcaf7f86SDimitry Andric // GVs/functions.
15*fcaf7f86SDimitry Andric //
16*fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
17*fcaf7f86SDimitry Andric 
18*fcaf7f86SDimitry Andric #include "SPIRV.h"
19*fcaf7f86SDimitry Andric #include "SPIRVTargetMachine.h"
20*fcaf7f86SDimitry Andric #include "SPIRVUtils.h"
21*fcaf7f86SDimitry Andric #include "llvm/IR/IRBuilder.h"
22*fcaf7f86SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
23*fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
24*fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
25*fcaf7f86SDimitry Andric 
26*fcaf7f86SDimitry Andric using namespace llvm;
27*fcaf7f86SDimitry Andric 
28*fcaf7f86SDimitry Andric namespace llvm {
29*fcaf7f86SDimitry Andric void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
30*fcaf7f86SDimitry Andric }
31*fcaf7f86SDimitry Andric 
32*fcaf7f86SDimitry Andric namespace {
33*fcaf7f86SDimitry Andric 
34*fcaf7f86SDimitry Andric class SPIRVPrepareFunctions : public ModulePass {
35*fcaf7f86SDimitry Andric   Function *processFunctionSignature(Function *F);
36*fcaf7f86SDimitry Andric 
37*fcaf7f86SDimitry Andric public:
38*fcaf7f86SDimitry Andric   static char ID;
39*fcaf7f86SDimitry Andric   SPIRVPrepareFunctions() : ModulePass(ID) {
40*fcaf7f86SDimitry Andric     initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
41*fcaf7f86SDimitry Andric   }
42*fcaf7f86SDimitry Andric 
43*fcaf7f86SDimitry Andric   bool runOnModule(Module &M) override;
44*fcaf7f86SDimitry Andric 
45*fcaf7f86SDimitry Andric   StringRef getPassName() const override { return "SPIRV prepare functions"; }
46*fcaf7f86SDimitry Andric 
47*fcaf7f86SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
48*fcaf7f86SDimitry Andric     ModulePass::getAnalysisUsage(AU);
49*fcaf7f86SDimitry Andric   }
50*fcaf7f86SDimitry Andric };
51*fcaf7f86SDimitry Andric 
52*fcaf7f86SDimitry Andric } // namespace
53*fcaf7f86SDimitry Andric 
54*fcaf7f86SDimitry Andric char SPIRVPrepareFunctions::ID = 0;
55*fcaf7f86SDimitry Andric 
56*fcaf7f86SDimitry Andric INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
57*fcaf7f86SDimitry Andric                 "SPIRV prepare functions", false, false)
58*fcaf7f86SDimitry Andric 
59*fcaf7f86SDimitry Andric Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) {
60*fcaf7f86SDimitry Andric   IRBuilder<> B(F->getContext());
61*fcaf7f86SDimitry Andric 
62*fcaf7f86SDimitry Andric   bool IsRetAggr = F->getReturnType()->isAggregateType();
63*fcaf7f86SDimitry Andric   bool HasAggrArg =
64*fcaf7f86SDimitry Andric       std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
65*fcaf7f86SDimitry Andric         return Arg.getType()->isAggregateType();
66*fcaf7f86SDimitry Andric       });
67*fcaf7f86SDimitry Andric   bool DoClone = IsRetAggr || HasAggrArg;
68*fcaf7f86SDimitry Andric   if (!DoClone)
69*fcaf7f86SDimitry Andric     return F;
70*fcaf7f86SDimitry Andric   SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
71*fcaf7f86SDimitry Andric   Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
72*fcaf7f86SDimitry Andric   if (IsRetAggr)
73*fcaf7f86SDimitry Andric     ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
74*fcaf7f86SDimitry Andric   SmallVector<Type *, 4> ArgTypes;
75*fcaf7f86SDimitry Andric   for (const auto &Arg : F->args()) {
76*fcaf7f86SDimitry Andric     if (Arg.getType()->isAggregateType()) {
77*fcaf7f86SDimitry Andric       ArgTypes.push_back(B.getInt32Ty());
78*fcaf7f86SDimitry Andric       ChangedTypes.push_back(
79*fcaf7f86SDimitry Andric           std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
80*fcaf7f86SDimitry Andric     } else
81*fcaf7f86SDimitry Andric       ArgTypes.push_back(Arg.getType());
82*fcaf7f86SDimitry Andric   }
83*fcaf7f86SDimitry Andric   FunctionType *NewFTy =
84*fcaf7f86SDimitry Andric       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
85*fcaf7f86SDimitry Andric   Function *NewF =
86*fcaf7f86SDimitry Andric       Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
87*fcaf7f86SDimitry Andric 
88*fcaf7f86SDimitry Andric   ValueToValueMapTy VMap;
89*fcaf7f86SDimitry Andric   auto NewFArgIt = NewF->arg_begin();
90*fcaf7f86SDimitry Andric   for (auto &Arg : F->args()) {
91*fcaf7f86SDimitry Andric     StringRef ArgName = Arg.getName();
92*fcaf7f86SDimitry Andric     NewFArgIt->setName(ArgName);
93*fcaf7f86SDimitry Andric     VMap[&Arg] = &(*NewFArgIt++);
94*fcaf7f86SDimitry Andric   }
95*fcaf7f86SDimitry Andric   SmallVector<ReturnInst *, 8> Returns;
96*fcaf7f86SDimitry Andric 
97*fcaf7f86SDimitry Andric   CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
98*fcaf7f86SDimitry Andric                     Returns);
99*fcaf7f86SDimitry Andric   NewF->takeName(F);
100*fcaf7f86SDimitry Andric 
101*fcaf7f86SDimitry Andric   NamedMDNode *FuncMD =
102*fcaf7f86SDimitry Andric       F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
103*fcaf7f86SDimitry Andric   SmallVector<Metadata *, 2> MDArgs;
104*fcaf7f86SDimitry Andric   MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
105*fcaf7f86SDimitry Andric   for (auto &ChangedTyP : ChangedTypes)
106*fcaf7f86SDimitry Andric     MDArgs.push_back(MDNode::get(
107*fcaf7f86SDimitry Andric         B.getContext(),
108*fcaf7f86SDimitry Andric         {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
109*fcaf7f86SDimitry Andric          ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
110*fcaf7f86SDimitry Andric   MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
111*fcaf7f86SDimitry Andric   FuncMD->addOperand(ThisFuncMD);
112*fcaf7f86SDimitry Andric 
113*fcaf7f86SDimitry Andric   for (auto *U : make_early_inc_range(F->users())) {
114*fcaf7f86SDimitry Andric     if (auto *CI = dyn_cast<CallInst>(U))
115*fcaf7f86SDimitry Andric       CI->mutateFunctionType(NewF->getFunctionType());
116*fcaf7f86SDimitry Andric     U->replaceUsesOfWith(F, NewF);
117*fcaf7f86SDimitry Andric   }
118*fcaf7f86SDimitry Andric   return NewF;
119*fcaf7f86SDimitry Andric }
120*fcaf7f86SDimitry Andric 
121*fcaf7f86SDimitry Andric std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
122*fcaf7f86SDimitry Andric   Function *IntrinsicFunc = II->getCalledFunction();
123*fcaf7f86SDimitry Andric   assert(IntrinsicFunc && "Missing function");
124*fcaf7f86SDimitry Andric   std::string FuncName = IntrinsicFunc->getName().str();
125*fcaf7f86SDimitry Andric   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
126*fcaf7f86SDimitry Andric   FuncName = "spirv." + FuncName;
127*fcaf7f86SDimitry Andric   return FuncName;
128*fcaf7f86SDimitry Andric }
129*fcaf7f86SDimitry Andric 
130*fcaf7f86SDimitry Andric static Function *getOrCreateFunction(Module *M, Type *RetTy,
131*fcaf7f86SDimitry Andric                                      ArrayRef<Type *> ArgTypes,
132*fcaf7f86SDimitry Andric                                      StringRef Name) {
133*fcaf7f86SDimitry Andric   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
134*fcaf7f86SDimitry Andric   Function *F = M->getFunction(Name);
135*fcaf7f86SDimitry Andric   if (F && F->getFunctionType() == FT)
136*fcaf7f86SDimitry Andric     return F;
137*fcaf7f86SDimitry Andric   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
138*fcaf7f86SDimitry Andric   if (F)
139*fcaf7f86SDimitry Andric     NewF->setDSOLocal(F->isDSOLocal());
140*fcaf7f86SDimitry Andric   NewF->setCallingConv(CallingConv::SPIR_FUNC);
141*fcaf7f86SDimitry Andric   return NewF;
142*fcaf7f86SDimitry Andric }
143*fcaf7f86SDimitry Andric 
144*fcaf7f86SDimitry Andric static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) {
145*fcaf7f86SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
146*fcaf7f86SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
147*fcaf7f86SDimitry Andric   // function.
148*fcaf7f86SDimitry Andric   // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
149*fcaf7f86SDimitry Andric   FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
150*fcaf7f86SDimitry Andric   Type *FSHRetTy = FSHFuncTy->getReturnType();
151*fcaf7f86SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
152*fcaf7f86SDimitry Andric   Function *FSHFunc =
153*fcaf7f86SDimitry Andric       getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
154*fcaf7f86SDimitry Andric 
155*fcaf7f86SDimitry Andric   if (!FSHFunc->empty()) {
156*fcaf7f86SDimitry Andric     FSHIntrinsic->setCalledFunction(FSHFunc);
157*fcaf7f86SDimitry Andric     return;
158*fcaf7f86SDimitry Andric   }
159*fcaf7f86SDimitry Andric   BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
160*fcaf7f86SDimitry Andric   IRBuilder<> IRB(RotateBB);
161*fcaf7f86SDimitry Andric   Type *Ty = FSHFunc->getReturnType();
162*fcaf7f86SDimitry Andric   // Build the actual funnel shift rotate logic.
163*fcaf7f86SDimitry Andric   // In the comments, "int" is used interchangeably with "vector of int
164*fcaf7f86SDimitry Andric   // elements".
165*fcaf7f86SDimitry Andric   FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
166*fcaf7f86SDimitry Andric   Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
167*fcaf7f86SDimitry Andric   unsigned BitWidth = IntTy->getIntegerBitWidth();
168*fcaf7f86SDimitry Andric   ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
169*fcaf7f86SDimitry Andric   Value *BitWidthForInsts =
170*fcaf7f86SDimitry Andric       VectorTy
171*fcaf7f86SDimitry Andric           ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
172*fcaf7f86SDimitry Andric           : BitWidthConstant;
173*fcaf7f86SDimitry Andric   Value *RotateModVal =
174*fcaf7f86SDimitry Andric       IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
175*fcaf7f86SDimitry Andric   Value *FirstShift = nullptr, *SecShift = nullptr;
176*fcaf7f86SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
177*fcaf7f86SDimitry Andric     // Shift the less significant number right, the "rotate" number of bits
178*fcaf7f86SDimitry Andric     // will be 0-filled on the left as a result of this regular shift.
179*fcaf7f86SDimitry Andric     FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
180*fcaf7f86SDimitry Andric   } else {
181*fcaf7f86SDimitry Andric     // Shift the more significant number left, the "rotate" number of bits
182*fcaf7f86SDimitry Andric     // will be 0-filled on the right as a result of this regular shift.
183*fcaf7f86SDimitry Andric     FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
184*fcaf7f86SDimitry Andric   }
185*fcaf7f86SDimitry Andric   // We want the "rotate" number of the more significant int's LSBs (MSBs) to
186*fcaf7f86SDimitry Andric   // occupy the leftmost (rightmost) "0 space" left by the previous operation.
187*fcaf7f86SDimitry Andric   // Therefore, subtract the "rotate" number from the integer bitsize...
188*fcaf7f86SDimitry Andric   Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
189*fcaf7f86SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
190*fcaf7f86SDimitry Andric     // ...and left-shift the more significant int by this number, zero-filling
191*fcaf7f86SDimitry Andric     // the LSBs.
192*fcaf7f86SDimitry Andric     SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
193*fcaf7f86SDimitry Andric   } else {
194*fcaf7f86SDimitry Andric     // ...and right-shift the less significant int by this number, zero-filling
195*fcaf7f86SDimitry Andric     // the MSBs.
196*fcaf7f86SDimitry Andric     SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
197*fcaf7f86SDimitry Andric   }
198*fcaf7f86SDimitry Andric   // A simple binary addition of the shifted ints yields the final result.
199*fcaf7f86SDimitry Andric   IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
200*fcaf7f86SDimitry Andric 
201*fcaf7f86SDimitry Andric   FSHIntrinsic->setCalledFunction(FSHFunc);
202*fcaf7f86SDimitry Andric }
203*fcaf7f86SDimitry Andric 
204*fcaf7f86SDimitry Andric static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) {
205*fcaf7f86SDimitry Andric   // The function body is already created.
206*fcaf7f86SDimitry Andric   if (!UMulFunc->empty())
207*fcaf7f86SDimitry Andric     return;
208*fcaf7f86SDimitry Andric 
209*fcaf7f86SDimitry Andric   BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc);
210*fcaf7f86SDimitry Andric   IRBuilder<> IRB(EntryBB);
211*fcaf7f86SDimitry Andric   // Build the actual unsigned multiplication logic with the overflow
212*fcaf7f86SDimitry Andric   // indication. Do unsigned multiplication Mul = A * B. Then check
213*fcaf7f86SDimitry Andric   // if unsigned division Div = Mul / A is not equal to B. If so,
214*fcaf7f86SDimitry Andric   // then overflow has happened.
215*fcaf7f86SDimitry Andric   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
216*fcaf7f86SDimitry Andric   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
217*fcaf7f86SDimitry Andric   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
218*fcaf7f86SDimitry Andric 
219*fcaf7f86SDimitry Andric   // umul.with.overflow intrinsic return a structure, where the first element
220*fcaf7f86SDimitry Andric   // is the multiplication result, and the second is an overflow bit.
221*fcaf7f86SDimitry Andric   Type *StructTy = UMulFunc->getReturnType();
222*fcaf7f86SDimitry Andric   Value *Agg = IRB.CreateInsertValue(UndefValue::get(StructTy), Mul, {0});
223*fcaf7f86SDimitry Andric   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
224*fcaf7f86SDimitry Andric   IRB.CreateRet(Res);
225*fcaf7f86SDimitry Andric }
226*fcaf7f86SDimitry Andric 
227*fcaf7f86SDimitry Andric static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) {
228*fcaf7f86SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
229*fcaf7f86SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
230*fcaf7f86SDimitry Andric   // function.
231*fcaf7f86SDimitry Andric   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
232*fcaf7f86SDimitry Andric   Type *FSHLRetTy = UMulFuncTy->getReturnType();
233*fcaf7f86SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
234*fcaf7f86SDimitry Andric   Function *UMulFunc =
235*fcaf7f86SDimitry Andric       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
236*fcaf7f86SDimitry Andric   buildUMulWithOverflowFunc(M, UMulFunc);
237*fcaf7f86SDimitry Andric   UMulIntrinsic->setCalledFunction(UMulFunc);
238*fcaf7f86SDimitry Andric }
239*fcaf7f86SDimitry Andric 
240*fcaf7f86SDimitry Andric static void substituteIntrinsicCalls(Module *M, Function *F) {
241*fcaf7f86SDimitry Andric   for (BasicBlock &BB : *F) {
242*fcaf7f86SDimitry Andric     for (Instruction &I : BB) {
243*fcaf7f86SDimitry Andric       auto Call = dyn_cast<CallInst>(&I);
244*fcaf7f86SDimitry Andric       if (!Call)
245*fcaf7f86SDimitry Andric         continue;
246*fcaf7f86SDimitry Andric       Call->setTailCall(false);
247*fcaf7f86SDimitry Andric       Function *CF = Call->getCalledFunction();
248*fcaf7f86SDimitry Andric       if (!CF || !CF->isIntrinsic())
249*fcaf7f86SDimitry Andric         continue;
250*fcaf7f86SDimitry Andric       auto *II = cast<IntrinsicInst>(Call);
251*fcaf7f86SDimitry Andric       if (II->getIntrinsicID() == Intrinsic::fshl ||
252*fcaf7f86SDimitry Andric           II->getIntrinsicID() == Intrinsic::fshr)
253*fcaf7f86SDimitry Andric         lowerFunnelShifts(M, II);
254*fcaf7f86SDimitry Andric       else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow)
255*fcaf7f86SDimitry Andric         lowerUMulWithOverflow(M, II);
256*fcaf7f86SDimitry Andric     }
257*fcaf7f86SDimitry Andric   }
258*fcaf7f86SDimitry Andric }
259*fcaf7f86SDimitry Andric 
260*fcaf7f86SDimitry Andric bool SPIRVPrepareFunctions::runOnModule(Module &M) {
261*fcaf7f86SDimitry Andric   for (Function &F : M)
262*fcaf7f86SDimitry Andric     substituteIntrinsicCalls(&M, &F);
263*fcaf7f86SDimitry Andric 
264*fcaf7f86SDimitry Andric   std::vector<Function *> FuncsWorklist;
265*fcaf7f86SDimitry Andric   bool Changed = false;
266*fcaf7f86SDimitry Andric   for (auto &F : M)
267*fcaf7f86SDimitry Andric     FuncsWorklist.push_back(&F);
268*fcaf7f86SDimitry Andric 
269*fcaf7f86SDimitry Andric   for (auto *Func : FuncsWorklist) {
270*fcaf7f86SDimitry Andric     Function *F = processFunctionSignature(Func);
271*fcaf7f86SDimitry Andric 
272*fcaf7f86SDimitry Andric     bool CreatedNewF = F != Func;
273*fcaf7f86SDimitry Andric 
274*fcaf7f86SDimitry Andric     if (Func->isDeclaration()) {
275*fcaf7f86SDimitry Andric       Changed |= CreatedNewF;
276*fcaf7f86SDimitry Andric       continue;
277*fcaf7f86SDimitry Andric     }
278*fcaf7f86SDimitry Andric 
279*fcaf7f86SDimitry Andric     if (CreatedNewF)
280*fcaf7f86SDimitry Andric       Func->eraseFromParent();
281*fcaf7f86SDimitry Andric   }
282*fcaf7f86SDimitry Andric 
283*fcaf7f86SDimitry Andric   return Changed;
284*fcaf7f86SDimitry Andric }
285*fcaf7f86SDimitry Andric 
286*fcaf7f86SDimitry Andric ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
287*fcaf7f86SDimitry Andric   return new SPIRVPrepareFunctions();
288*fcaf7f86SDimitry Andric }
289