1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass modifies function signatures containing aggregate arguments
// and/or an aggregate return value. It also substitutes some LLVM intrinsic
// calls with calls to generated functions, generating these functions as the
// translator does.
12 //
13 // NOTE: this pass is a module-level one due to the necessity to modify
14 // GVs/functions.
15 //
16 //===----------------------------------------------------------------------===//
17
18 #include "SPIRV.h"
19 #include "SPIRVTargetMachine.h"
20 #include "SPIRVUtils.h"
21 #include "llvm/CodeGen/IntrinsicLowering.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/IntrinsicInst.h"
24 #include "llvm/Transforms/Utils/Cloning.h"
25 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
26
27 using namespace llvm;
28
namespace llvm {
// Registration hook called from the SPIRVPrepareFunctions constructor below.
// Its definition is presumably produced by the INITIALIZE_PASS macro later in
// this file (TODO confirm).
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
} // namespace llvm
32
33 namespace {
34
35 class SPIRVPrepareFunctions : public ModulePass {
36 Function *processFunctionSignature(Function *F);
37
38 public:
39 static char ID;
SPIRVPrepareFunctions()40 SPIRVPrepareFunctions() : ModulePass(ID) {
41 initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
42 }
43
44 bool runOnModule(Module &M) override;
45
getPassName() const46 StringRef getPassName() const override { return "SPIRV prepare functions"; }
47
getAnalysisUsage(AnalysisUsage & AU) const48 void getAnalysisUsage(AnalysisUsage &AU) const override {
49 ModulePass::getAnalysisUsage(AU);
50 }
51 };
52
53 } // namespace
54
// Storage for the pass identifier. Per LLVM convention only its address is
// used to identify the pass; the stored value is irrelevant.
char SPIRVPrepareFunctions::ID = 0;

// Registers the pass under the command-line name "prepare-functions". The two
// trailing arguments are INITIALIZE_PASS's "CFG only" and "is analysis" flags.
INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
59
processFunctionSignature(Function * F)60 Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) {
61 IRBuilder<> B(F->getContext());
62
63 bool IsRetAggr = F->getReturnType()->isAggregateType();
64 bool HasAggrArg =
65 std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
66 return Arg.getType()->isAggregateType();
67 });
68 bool DoClone = IsRetAggr || HasAggrArg;
69 if (!DoClone)
70 return F;
71 SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
72 Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
73 if (IsRetAggr)
74 ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
75 SmallVector<Type *, 4> ArgTypes;
76 for (const auto &Arg : F->args()) {
77 if (Arg.getType()->isAggregateType()) {
78 ArgTypes.push_back(B.getInt32Ty());
79 ChangedTypes.push_back(
80 std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
81 } else
82 ArgTypes.push_back(Arg.getType());
83 }
84 FunctionType *NewFTy =
85 FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
86 Function *NewF =
87 Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
88
89 ValueToValueMapTy VMap;
90 auto NewFArgIt = NewF->arg_begin();
91 for (auto &Arg : F->args()) {
92 StringRef ArgName = Arg.getName();
93 NewFArgIt->setName(ArgName);
94 VMap[&Arg] = &(*NewFArgIt++);
95 }
96 SmallVector<ReturnInst *, 8> Returns;
97
98 CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
99 Returns);
100 NewF->takeName(F);
101
102 NamedMDNode *FuncMD =
103 F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
104 SmallVector<Metadata *, 2> MDArgs;
105 MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
106 for (auto &ChangedTyP : ChangedTypes)
107 MDArgs.push_back(MDNode::get(
108 B.getContext(),
109 {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
110 ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
111 MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
112 FuncMD->addOperand(ThisFuncMD);
113
114 for (auto *U : make_early_inc_range(F->users())) {
115 if (auto *CI = dyn_cast<CallInst>(U))
116 CI->mutateFunctionType(NewF->getFunctionType());
117 U->replaceUsesOfWith(F, NewF);
118 }
119 return NewF;
120 }
121
lowerLLVMIntrinsicName(IntrinsicInst * II)122 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
123 Function *IntrinsicFunc = II->getCalledFunction();
124 assert(IntrinsicFunc && "Missing function");
125 std::string FuncName = IntrinsicFunc->getName().str();
126 std::replace(FuncName.begin(), FuncName.end(), '.', '_');
127 FuncName = "spirv." + FuncName;
128 return FuncName;
129 }
130
getOrCreateFunction(Module * M,Type * RetTy,ArrayRef<Type * > ArgTypes,StringRef Name)131 static Function *getOrCreateFunction(Module *M, Type *RetTy,
132 ArrayRef<Type *> ArgTypes,
133 StringRef Name) {
134 FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
135 Function *F = M->getFunction(Name);
136 if (F && F->getFunctionType() == FT)
137 return F;
138 Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
139 if (F)
140 NewF->setDSOLocal(F->isDSOLocal());
141 NewF->setCallingConv(CallingConv::SPIR_FUNC);
142 return NewF;
143 }
144
lowerIntrinsicToFunction(Module * M,IntrinsicInst * Intrinsic)145 static void lowerIntrinsicToFunction(Module *M, IntrinsicInst *Intrinsic) {
146 // For @llvm.memset.* intrinsic cases with constant value and length arguments
147 // are emulated via "storing" a constant array to the destination. For other
148 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
149 // intrinsic to a loop via expandMemSetAsLoop().
150 if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
151 if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
152 return; // It is handled later using OpCopyMemorySized.
153
154 std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
155 if (Intrinsic->isVolatile())
156 FuncName += ".volatile";
157 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
158 Function *F = M->getFunction(FuncName);
159 if (F) {
160 Intrinsic->setCalledFunction(F);
161 return;
162 }
163 // TODO copy arguments attributes: nocapture writeonly.
164 FunctionCallee FC =
165 M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
166 auto IntrinsicID = Intrinsic->getIntrinsicID();
167 Intrinsic->setCalledFunction(FC);
168
169 F = dyn_cast<Function>(FC.getCallee());
170 assert(F && "Callee must be a function");
171
172 switch (IntrinsicID) {
173 case Intrinsic::memset: {
174 auto *MSI = static_cast<MemSetInst *>(Intrinsic);
175 Argument *Dest = F->getArg(0);
176 Argument *Val = F->getArg(1);
177 Argument *Len = F->getArg(2);
178 Argument *IsVolatile = F->getArg(3);
179 Dest->setName("dest");
180 Val->setName("val");
181 Len->setName("len");
182 IsVolatile->setName("isvolatile");
183 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
184 IRBuilder<> IRB(EntryBB);
185 auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
186 MSI->isVolatile());
187 IRB.CreateRetVoid();
188 expandMemSetAsLoop(cast<MemSetInst>(MemSet));
189 MemSet->eraseFromParent();
190 break;
191 }
192 case Intrinsic::bswap: {
193 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
194 IRBuilder<> IRB(EntryBB);
195 auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
196 F->getArg(0));
197 IRB.CreateRet(BSwap);
198 IntrinsicLowering IL(M->getDataLayout());
199 IL.LowerIntrinsicCall(BSwap);
200 break;
201 }
202 default:
203 break;
204 }
205 return;
206 }
207
lowerFunnelShifts(Module * M,IntrinsicInst * FSHIntrinsic)208 static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) {
209 // Get a separate function - otherwise, we'd have to rework the CFG of the
210 // current one. Then simply replace the intrinsic uses with a call to the new
211 // function.
212 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
213 FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
214 Type *FSHRetTy = FSHFuncTy->getReturnType();
215 const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
216 Function *FSHFunc =
217 getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
218
219 if (!FSHFunc->empty()) {
220 FSHIntrinsic->setCalledFunction(FSHFunc);
221 return;
222 }
223 BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
224 IRBuilder<> IRB(RotateBB);
225 Type *Ty = FSHFunc->getReturnType();
226 // Build the actual funnel shift rotate logic.
227 // In the comments, "int" is used interchangeably with "vector of int
228 // elements".
229 FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
230 Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
231 unsigned BitWidth = IntTy->getIntegerBitWidth();
232 ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
233 Value *BitWidthForInsts =
234 VectorTy
235 ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
236 : BitWidthConstant;
237 Value *RotateModVal =
238 IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
239 Value *FirstShift = nullptr, *SecShift = nullptr;
240 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
241 // Shift the less significant number right, the "rotate" number of bits
242 // will be 0-filled on the left as a result of this regular shift.
243 FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
244 } else {
245 // Shift the more significant number left, the "rotate" number of bits
246 // will be 0-filled on the right as a result of this regular shift.
247 FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
248 }
249 // We want the "rotate" number of the more significant int's LSBs (MSBs) to
250 // occupy the leftmost (rightmost) "0 space" left by the previous operation.
251 // Therefore, subtract the "rotate" number from the integer bitsize...
252 Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
253 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
254 // ...and left-shift the more significant int by this number, zero-filling
255 // the LSBs.
256 SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
257 } else {
258 // ...and right-shift the less significant int by this number, zero-filling
259 // the MSBs.
260 SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
261 }
262 // A simple binary addition of the shifted ints yields the final result.
263 IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
264
265 FSHIntrinsic->setCalledFunction(FSHFunc);
266 }
267
buildUMulWithOverflowFunc(Module * M,Function * UMulFunc)268 static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) {
269 // The function body is already created.
270 if (!UMulFunc->empty())
271 return;
272
273 BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc);
274 IRBuilder<> IRB(EntryBB);
275 // Build the actual unsigned multiplication logic with the overflow
276 // indication. Do unsigned multiplication Mul = A * B. Then check
277 // if unsigned division Div = Mul / A is not equal to B. If so,
278 // then overflow has happened.
279 Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
280 Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
281 Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
282
283 // umul.with.overflow intrinsic return a structure, where the first element
284 // is the multiplication result, and the second is an overflow bit.
285 Type *StructTy = UMulFunc->getReturnType();
286 Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
287 Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
288 IRB.CreateRet(Res);
289 }
290
lowerUMulWithOverflow(Module * M,IntrinsicInst * UMulIntrinsic)291 static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) {
292 // Get a separate function - otherwise, we'd have to rework the CFG of the
293 // current one. Then simply replace the intrinsic uses with a call to the new
294 // function.
295 FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
296 Type *FSHLRetTy = UMulFuncTy->getReturnType();
297 const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
298 Function *UMulFunc =
299 getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
300 buildUMulWithOverflowFunc(M, UMulFunc);
301 UMulIntrinsic->setCalledFunction(UMulFunc);
302 }
303
substituteIntrinsicCalls(Module * M,Function * F)304 static void substituteIntrinsicCalls(Module *M, Function *F) {
305 for (BasicBlock &BB : *F) {
306 for (Instruction &I : BB) {
307 auto Call = dyn_cast<CallInst>(&I);
308 if (!Call)
309 continue;
310 Call->setTailCall(false);
311 Function *CF = Call->getCalledFunction();
312 if (!CF || !CF->isIntrinsic())
313 continue;
314 auto *II = cast<IntrinsicInst>(Call);
315 if (II->getIntrinsicID() == Intrinsic::memset ||
316 II->getIntrinsicID() == Intrinsic::bswap)
317 lowerIntrinsicToFunction(M, II);
318 else if (II->getIntrinsicID() == Intrinsic::fshl ||
319 II->getIntrinsicID() == Intrinsic::fshr)
320 lowerFunnelShifts(M, II);
321 else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow)
322 lowerUMulWithOverflow(M, II);
323 }
324 }
325 }
326
runOnModule(Module & M)327 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
328 for (Function &F : M)
329 substituteIntrinsicCalls(&M, &F);
330
331 std::vector<Function *> FuncsWorklist;
332 bool Changed = false;
333 for (auto &F : M)
334 FuncsWorklist.push_back(&F);
335
336 for (auto *Func : FuncsWorklist) {
337 Function *F = processFunctionSignature(Func);
338
339 bool CreatedNewF = F != Func;
340
341 if (Func->isDeclaration()) {
342 Changed |= CreatedNewF;
343 continue;
344 }
345
346 if (CreatedNewF)
347 Func->eraseFromParent();
348 }
349
350 return Changed;
351 }
352
createSPIRVPrepareFunctionsPass()353 ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
354 return new SPIRVPrepareFunctions();
355 }
356