xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1fcaf7f86SDimitry Andric //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2fcaf7f86SDimitry Andric //
3fcaf7f86SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fcaf7f86SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5fcaf7f86SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fcaf7f86SDimitry Andric //
7fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
8fcaf7f86SDimitry Andric //
9fcaf7f86SDimitry Andric // This pass modifies function signatures containing aggregate arguments
1006c3fb27SDimitry Andric // and/or return value before IRTranslator. Information about the original
1106c3fb27SDimitry Andric // signatures is stored in metadata. It is used during call lowering to
1206c3fb27SDimitry Andric // restore correct SPIR-V types of function arguments and return values.
1306c3fb27SDimitry Andric // This pass also substitutes some llvm intrinsic calls with calls to newly
1406c3fb27SDimitry Andric // generated functions (as the Khronos LLVM/SPIR-V Translator does).
15fcaf7f86SDimitry Andric //
16fcaf7f86SDimitry Andric // NOTE: this pass is a module-level one due to the necessity to modify
17fcaf7f86SDimitry Andric // GVs/functions.
18fcaf7f86SDimitry Andric //
19fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
20fcaf7f86SDimitry Andric 
21fcaf7f86SDimitry Andric #include "SPIRV.h"
225f757f3fSDimitry Andric #include "SPIRVSubtarget.h"
23fcaf7f86SDimitry Andric #include "SPIRVTargetMachine.h"
24fcaf7f86SDimitry Andric #include "SPIRVUtils.h"
25*0fca6ea1SDimitry Andric #include "llvm/Analysis/ValueTracking.h"
26bdd1243dSDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h"
27fcaf7f86SDimitry Andric #include "llvm/IR/IRBuilder.h"
28fcaf7f86SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
295f757f3fSDimitry Andric #include "llvm/IR/Intrinsics.h"
305f757f3fSDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
31fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
32fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
33*0fca6ea1SDimitry Andric #include <charconv>
34*0fca6ea1SDimitry Andric #include <regex>
35fcaf7f86SDimitry Andric 
36fcaf7f86SDimitry Andric using namespace llvm;
37fcaf7f86SDimitry Andric 
38fcaf7f86SDimitry Andric namespace llvm {
39fcaf7f86SDimitry Andric void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
40fcaf7f86SDimitry Andric }
41fcaf7f86SDimitry Andric 
42fcaf7f86SDimitry Andric namespace {
43fcaf7f86SDimitry Andric 
44fcaf7f86SDimitry Andric class SPIRVPrepareFunctions : public ModulePass {
455f757f3fSDimitry Andric   const SPIRVTargetMachine &TM;
4606c3fb27SDimitry Andric   bool substituteIntrinsicCalls(Function *F);
4706c3fb27SDimitry Andric   Function *removeAggregateTypesFromSignature(Function *F);
48fcaf7f86SDimitry Andric 
49fcaf7f86SDimitry Andric public:
50fcaf7f86SDimitry Andric   static char ID;
515f757f3fSDimitry Andric   SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
52fcaf7f86SDimitry Andric     initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
53fcaf7f86SDimitry Andric   }
54fcaf7f86SDimitry Andric 
55fcaf7f86SDimitry Andric   bool runOnModule(Module &M) override;
56fcaf7f86SDimitry Andric 
57fcaf7f86SDimitry Andric   StringRef getPassName() const override { return "SPIRV prepare functions"; }
58fcaf7f86SDimitry Andric 
59fcaf7f86SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
60fcaf7f86SDimitry Andric     ModulePass::getAnalysisUsage(AU);
61fcaf7f86SDimitry Andric   }
62fcaf7f86SDimitry Andric };
63fcaf7f86SDimitry Andric 
64fcaf7f86SDimitry Andric } // namespace
65fcaf7f86SDimitry Andric 
66fcaf7f86SDimitry Andric char SPIRVPrepareFunctions::ID = 0;
67fcaf7f86SDimitry Andric 
68fcaf7f86SDimitry Andric INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
69fcaf7f86SDimitry Andric                 "SPIRV prepare functions", false, false)
70fcaf7f86SDimitry Andric 
7106c3fb27SDimitry Andric std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
7206c3fb27SDimitry Andric   Function *IntrinsicFunc = II->getCalledFunction();
7306c3fb27SDimitry Andric   assert(IntrinsicFunc && "Missing function");
7406c3fb27SDimitry Andric   std::string FuncName = IntrinsicFunc->getName().str();
7506c3fb27SDimitry Andric   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
7606c3fb27SDimitry Andric   FuncName = "spirv." + FuncName;
7706c3fb27SDimitry Andric   return FuncName;
7806c3fb27SDimitry Andric }
7906c3fb27SDimitry Andric 
8006c3fb27SDimitry Andric static Function *getOrCreateFunction(Module *M, Type *RetTy,
8106c3fb27SDimitry Andric                                      ArrayRef<Type *> ArgTypes,
8206c3fb27SDimitry Andric                                      StringRef Name) {
8306c3fb27SDimitry Andric   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
8406c3fb27SDimitry Andric   Function *F = M->getFunction(Name);
8506c3fb27SDimitry Andric   if (F && F->getFunctionType() == FT)
8606c3fb27SDimitry Andric     return F;
8706c3fb27SDimitry Andric   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
8806c3fb27SDimitry Andric   if (F)
8906c3fb27SDimitry Andric     NewF->setDSOLocal(F->isDSOLocal());
9006c3fb27SDimitry Andric   NewF->setCallingConv(CallingConv::SPIR_FUNC);
9106c3fb27SDimitry Andric   return NewF;
9206c3fb27SDimitry Andric }
9306c3fb27SDimitry Andric 
9406c3fb27SDimitry Andric static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
9506c3fb27SDimitry Andric   // For @llvm.memset.* intrinsic cases with constant value and length arguments
9606c3fb27SDimitry Andric   // are emulated via "storing" a constant array to the destination. For other
9706c3fb27SDimitry Andric   // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
9806c3fb27SDimitry Andric   // intrinsic to a loop via expandMemSetAsLoop().
9906c3fb27SDimitry Andric   if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
10006c3fb27SDimitry Andric     if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
10106c3fb27SDimitry Andric       return false; // It is handled later using OpCopyMemorySized.
10206c3fb27SDimitry Andric 
10306c3fb27SDimitry Andric   Module *M = Intrinsic->getModule();
10406c3fb27SDimitry Andric   std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
10506c3fb27SDimitry Andric   if (Intrinsic->isVolatile())
10606c3fb27SDimitry Andric     FuncName += ".volatile";
10706c3fb27SDimitry Andric   // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
10806c3fb27SDimitry Andric   Function *F = M->getFunction(FuncName);
10906c3fb27SDimitry Andric   if (F) {
11006c3fb27SDimitry Andric     Intrinsic->setCalledFunction(F);
11106c3fb27SDimitry Andric     return true;
11206c3fb27SDimitry Andric   }
11306c3fb27SDimitry Andric   // TODO copy arguments attributes: nocapture writeonly.
11406c3fb27SDimitry Andric   FunctionCallee FC =
11506c3fb27SDimitry Andric       M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
11606c3fb27SDimitry Andric   auto IntrinsicID = Intrinsic->getIntrinsicID();
11706c3fb27SDimitry Andric   Intrinsic->setCalledFunction(FC);
11806c3fb27SDimitry Andric 
11906c3fb27SDimitry Andric   F = dyn_cast<Function>(FC.getCallee());
12006c3fb27SDimitry Andric   assert(F && "Callee must be a function");
12106c3fb27SDimitry Andric 
12206c3fb27SDimitry Andric   switch (IntrinsicID) {
12306c3fb27SDimitry Andric   case Intrinsic::memset: {
12406c3fb27SDimitry Andric     auto *MSI = static_cast<MemSetInst *>(Intrinsic);
12506c3fb27SDimitry Andric     Argument *Dest = F->getArg(0);
12606c3fb27SDimitry Andric     Argument *Val = F->getArg(1);
12706c3fb27SDimitry Andric     Argument *Len = F->getArg(2);
12806c3fb27SDimitry Andric     Argument *IsVolatile = F->getArg(3);
12906c3fb27SDimitry Andric     Dest->setName("dest");
13006c3fb27SDimitry Andric     Val->setName("val");
13106c3fb27SDimitry Andric     Len->setName("len");
13206c3fb27SDimitry Andric     IsVolatile->setName("isvolatile");
13306c3fb27SDimitry Andric     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
13406c3fb27SDimitry Andric     IRBuilder<> IRB(EntryBB);
13506c3fb27SDimitry Andric     auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
13606c3fb27SDimitry Andric                                     MSI->isVolatile());
13706c3fb27SDimitry Andric     IRB.CreateRetVoid();
13806c3fb27SDimitry Andric     expandMemSetAsLoop(cast<MemSetInst>(MemSet));
13906c3fb27SDimitry Andric     MemSet->eraseFromParent();
14006c3fb27SDimitry Andric     break;
14106c3fb27SDimitry Andric   }
14206c3fb27SDimitry Andric   case Intrinsic::bswap: {
14306c3fb27SDimitry Andric     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
14406c3fb27SDimitry Andric     IRBuilder<> IRB(EntryBB);
14506c3fb27SDimitry Andric     auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
14606c3fb27SDimitry Andric                                       F->getArg(0));
14706c3fb27SDimitry Andric     IRB.CreateRet(BSwap);
14806c3fb27SDimitry Andric     IntrinsicLowering IL(M->getDataLayout());
14906c3fb27SDimitry Andric     IL.LowerIntrinsicCall(BSwap);
15006c3fb27SDimitry Andric     break;
15106c3fb27SDimitry Andric   }
15206c3fb27SDimitry Andric   default:
15306c3fb27SDimitry Andric     break;
15406c3fb27SDimitry Andric   }
15506c3fb27SDimitry Andric   return true;
15606c3fb27SDimitry Andric }
15706c3fb27SDimitry Andric 
158*0fca6ea1SDimitry Andric static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) {
159*0fca6ea1SDimitry Andric   if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(AnnoVal))
160*0fca6ea1SDimitry Andric     AnnoVal = Ref->getOperand(0);
161*0fca6ea1SDimitry Andric   if (auto *Ref = dyn_cast_or_null<BitCastInst>(OptAnnoVal))
162*0fca6ea1SDimitry Andric     OptAnnoVal = Ref->getOperand(0);
163*0fca6ea1SDimitry Andric 
164*0fca6ea1SDimitry Andric   std::string Anno;
165*0fca6ea1SDimitry Andric   if (auto *C = dyn_cast_or_null<Constant>(AnnoVal)) {
166*0fca6ea1SDimitry Andric     StringRef Str;
167*0fca6ea1SDimitry Andric     if (getConstantStringInfo(C, Str))
168*0fca6ea1SDimitry Andric       Anno = Str;
169*0fca6ea1SDimitry Andric   }
170*0fca6ea1SDimitry Andric   // handle optional annotation parameter in a way that Khronos Translator do
171*0fca6ea1SDimitry Andric   // (collect integers wrapped in a struct)
172*0fca6ea1SDimitry Andric   if (auto *C = dyn_cast_or_null<Constant>(OptAnnoVal);
173*0fca6ea1SDimitry Andric       C && C->getNumOperands()) {
174*0fca6ea1SDimitry Andric     Value *MaybeStruct = C->getOperand(0);
175*0fca6ea1SDimitry Andric     if (auto *Struct = dyn_cast<ConstantStruct>(MaybeStruct)) {
176*0fca6ea1SDimitry Andric       for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) {
177*0fca6ea1SDimitry Andric         if (auto *CInt = dyn_cast<ConstantInt>(Struct->getOperand(I)))
178*0fca6ea1SDimitry Andric           Anno += (I == 0 ? ": " : ", ") +
179*0fca6ea1SDimitry Andric                   std::to_string(CInt->getType()->getIntegerBitWidth() == 1
180*0fca6ea1SDimitry Andric                                      ? CInt->getZExtValue()
181*0fca6ea1SDimitry Andric                                      : CInt->getSExtValue());
182*0fca6ea1SDimitry Andric       }
183*0fca6ea1SDimitry Andric     } else if (auto *Struct = dyn_cast<ConstantAggregateZero>(MaybeStruct)) {
184*0fca6ea1SDimitry Andric       // { i32 i32 ... } zeroinitializer
185*0fca6ea1SDimitry Andric       for (unsigned I = 0, E = Struct->getType()->getStructNumElements();
186*0fca6ea1SDimitry Andric            I != E; ++I)
187*0fca6ea1SDimitry Andric         Anno += I == 0 ? ": 0" : ", 0";
188*0fca6ea1SDimitry Andric     }
189*0fca6ea1SDimitry Andric   }
190*0fca6ea1SDimitry Andric   return Anno;
191*0fca6ea1SDimitry Andric }
192*0fca6ea1SDimitry Andric 
193*0fca6ea1SDimitry Andric static SmallVector<Metadata *> parseAnnotation(Value *I,
194*0fca6ea1SDimitry Andric                                                const std::string &Anno,
195*0fca6ea1SDimitry Andric                                                LLVMContext &Ctx,
196*0fca6ea1SDimitry Andric                                                Type *Int32Ty) {
197*0fca6ea1SDimitry Andric   // Try to parse the annotation string according to the following rules:
198*0fca6ea1SDimitry Andric   // annotation := ({kind} | {kind:value,value,...})+
199*0fca6ea1SDimitry Andric   // kind := number
200*0fca6ea1SDimitry Andric   // value := number | string
201*0fca6ea1SDimitry Andric   static const std::regex R(
202*0fca6ea1SDimitry Andric       "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");
203*0fca6ea1SDimitry Andric   SmallVector<Metadata *> MDs;
204*0fca6ea1SDimitry Andric   int Pos = 0;
205*0fca6ea1SDimitry Andric   for (std::sregex_iterator
206*0fca6ea1SDimitry Andric            It = std::sregex_iterator(Anno.begin(), Anno.end(), R),
207*0fca6ea1SDimitry Andric            ItEnd = std::sregex_iterator();
208*0fca6ea1SDimitry Andric        It != ItEnd; ++It) {
209*0fca6ea1SDimitry Andric     if (It->position() != Pos)
210*0fca6ea1SDimitry Andric       return SmallVector<Metadata *>{};
211*0fca6ea1SDimitry Andric     Pos = It->position() + It->length();
212*0fca6ea1SDimitry Andric     std::smatch Match = *It;
213*0fca6ea1SDimitry Andric     SmallVector<Metadata *> MDsItem;
214*0fca6ea1SDimitry Andric     for (std::size_t i = 1; i < Match.size(); ++i) {
215*0fca6ea1SDimitry Andric       std::ssub_match SMatch = Match[i];
216*0fca6ea1SDimitry Andric       std::string Item = SMatch.str();
217*0fca6ea1SDimitry Andric       if (Item.length() == 0)
218*0fca6ea1SDimitry Andric         break;
219*0fca6ea1SDimitry Andric       if (Item[0] == '"') {
220*0fca6ea1SDimitry Andric         Item = Item.substr(1, Item.length() - 2);
221*0fca6ea1SDimitry Andric         // Acceptable format of the string snippet is:
222*0fca6ea1SDimitry Andric         static const std::regex RStr("^(\\d+)(?:,(\\d+))*$");
223*0fca6ea1SDimitry Andric         if (std::smatch MatchStr; std::regex_match(Item, MatchStr, RStr)) {
224*0fca6ea1SDimitry Andric           for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)
225*0fca6ea1SDimitry Andric             if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length())
226*0fca6ea1SDimitry Andric               MDsItem.push_back(ConstantAsMetadata::get(
227*0fca6ea1SDimitry Andric                   ConstantInt::get(Int32Ty, std::stoi(SubStr))));
228*0fca6ea1SDimitry Andric         } else {
229*0fca6ea1SDimitry Andric           MDsItem.push_back(MDString::get(Ctx, Item));
230*0fca6ea1SDimitry Andric         }
231*0fca6ea1SDimitry Andric       } else if (int32_t Num;
232*0fca6ea1SDimitry Andric                  std::from_chars(Item.data(), Item.data() + Item.size(), Num)
233*0fca6ea1SDimitry Andric                      .ec == std::errc{}) {
234*0fca6ea1SDimitry Andric         MDsItem.push_back(
235*0fca6ea1SDimitry Andric             ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Num)));
236*0fca6ea1SDimitry Andric       } else {
237*0fca6ea1SDimitry Andric         MDsItem.push_back(MDString::get(Ctx, Item));
238*0fca6ea1SDimitry Andric       }
239*0fca6ea1SDimitry Andric     }
240*0fca6ea1SDimitry Andric     if (MDsItem.size() == 0)
241*0fca6ea1SDimitry Andric       return SmallVector<Metadata *>{};
242*0fca6ea1SDimitry Andric     MDs.push_back(MDNode::get(Ctx, MDsItem));
243*0fca6ea1SDimitry Andric   }
244*0fca6ea1SDimitry Andric   return Pos == static_cast<int>(Anno.length()) ? MDs
245*0fca6ea1SDimitry Andric                                                 : SmallVector<Metadata *>{};
246*0fca6ea1SDimitry Andric }
247*0fca6ea1SDimitry Andric 
248*0fca6ea1SDimitry Andric static void lowerPtrAnnotation(IntrinsicInst *II) {
249*0fca6ea1SDimitry Andric   LLVMContext &Ctx = II->getContext();
250*0fca6ea1SDimitry Andric   Type *Int32Ty = Type::getInt32Ty(Ctx);
251*0fca6ea1SDimitry Andric 
252*0fca6ea1SDimitry Andric   // Retrieve an annotation string from arguments.
253*0fca6ea1SDimitry Andric   Value *PtrArg = nullptr;
254*0fca6ea1SDimitry Andric   if (auto *BI = dyn_cast<BitCastInst>(II->getArgOperand(0)))
255*0fca6ea1SDimitry Andric     PtrArg = BI->getOperand(0);
256*0fca6ea1SDimitry Andric   else
257*0fca6ea1SDimitry Andric     PtrArg = II->getOperand(0);
258*0fca6ea1SDimitry Andric   std::string Anno =
259*0fca6ea1SDimitry Andric       getAnnotation(II->getArgOperand(1),
260*0fca6ea1SDimitry Andric                     4 < II->arg_size() ? II->getArgOperand(4) : nullptr);
261*0fca6ea1SDimitry Andric 
262*0fca6ea1SDimitry Andric   // Parse the annotation.
263*0fca6ea1SDimitry Andric   SmallVector<Metadata *> MDs = parseAnnotation(II, Anno, Ctx, Int32Ty);
264*0fca6ea1SDimitry Andric 
265*0fca6ea1SDimitry Andric   // If the annotation string is not parsed successfully we don't know the
266*0fca6ea1SDimitry Andric   // format used and output it as a general UserSemantic decoration.
267*0fca6ea1SDimitry Andric   // Otherwise MDs is a Metadata tuple (a decoration list) in the format
268*0fca6ea1SDimitry Andric   // expected by `spirv.Decorations`.
269*0fca6ea1SDimitry Andric   if (MDs.size() == 0) {
270*0fca6ea1SDimitry Andric     auto UserSemantic = ConstantAsMetadata::get(ConstantInt::get(
271*0fca6ea1SDimitry Andric         Int32Ty, static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));
272*0fca6ea1SDimitry Andric     MDs.push_back(MDNode::get(Ctx, {UserSemantic, MDString::get(Ctx, Anno)}));
273*0fca6ea1SDimitry Andric   }
274*0fca6ea1SDimitry Andric 
275*0fca6ea1SDimitry Andric   // Build the internal intrinsic function.
276*0fca6ea1SDimitry Andric   IRBuilder<> IRB(II->getParent());
277*0fca6ea1SDimitry Andric   IRB.SetInsertPoint(II);
278*0fca6ea1SDimitry Andric   IRB.CreateIntrinsic(
279*0fca6ea1SDimitry Andric       Intrinsic::spv_assign_decoration, {PtrArg->getType()},
280*0fca6ea1SDimitry Andric       {PtrArg, MetadataAsValue::get(Ctx, MDNode::get(Ctx, MDs))});
281*0fca6ea1SDimitry Andric   II->replaceAllUsesWith(II->getOperand(0));
282*0fca6ea1SDimitry Andric }
283*0fca6ea1SDimitry Andric 
28406c3fb27SDimitry Andric static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
28506c3fb27SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
28606c3fb27SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
28706c3fb27SDimitry Andric   // function.
28806c3fb27SDimitry Andric   // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
28906c3fb27SDimitry Andric   Module *M = FSHIntrinsic->getModule();
29006c3fb27SDimitry Andric   FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
29106c3fb27SDimitry Andric   Type *FSHRetTy = FSHFuncTy->getReturnType();
29206c3fb27SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
29306c3fb27SDimitry Andric   Function *FSHFunc =
29406c3fb27SDimitry Andric       getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
29506c3fb27SDimitry Andric 
29606c3fb27SDimitry Andric   if (!FSHFunc->empty()) {
29706c3fb27SDimitry Andric     FSHIntrinsic->setCalledFunction(FSHFunc);
29806c3fb27SDimitry Andric     return;
29906c3fb27SDimitry Andric   }
30006c3fb27SDimitry Andric   BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
30106c3fb27SDimitry Andric   IRBuilder<> IRB(RotateBB);
30206c3fb27SDimitry Andric   Type *Ty = FSHFunc->getReturnType();
30306c3fb27SDimitry Andric   // Build the actual funnel shift rotate logic.
30406c3fb27SDimitry Andric   // In the comments, "int" is used interchangeably with "vector of int
30506c3fb27SDimitry Andric   // elements".
30606c3fb27SDimitry Andric   FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
30706c3fb27SDimitry Andric   Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
30806c3fb27SDimitry Andric   unsigned BitWidth = IntTy->getIntegerBitWidth();
30906c3fb27SDimitry Andric   ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
31006c3fb27SDimitry Andric   Value *BitWidthForInsts =
31106c3fb27SDimitry Andric       VectorTy
31206c3fb27SDimitry Andric           ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
31306c3fb27SDimitry Andric           : BitWidthConstant;
31406c3fb27SDimitry Andric   Value *RotateModVal =
31506c3fb27SDimitry Andric       IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
31606c3fb27SDimitry Andric   Value *FirstShift = nullptr, *SecShift = nullptr;
31706c3fb27SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
31806c3fb27SDimitry Andric     // Shift the less significant number right, the "rotate" number of bits
31906c3fb27SDimitry Andric     // will be 0-filled on the left as a result of this regular shift.
32006c3fb27SDimitry Andric     FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
32106c3fb27SDimitry Andric   } else {
32206c3fb27SDimitry Andric     // Shift the more significant number left, the "rotate" number of bits
32306c3fb27SDimitry Andric     // will be 0-filled on the right as a result of this regular shift.
32406c3fb27SDimitry Andric     FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
32506c3fb27SDimitry Andric   }
32606c3fb27SDimitry Andric   // We want the "rotate" number of the more significant int's LSBs (MSBs) to
32706c3fb27SDimitry Andric   // occupy the leftmost (rightmost) "0 space" left by the previous operation.
32806c3fb27SDimitry Andric   // Therefore, subtract the "rotate" number from the integer bitsize...
32906c3fb27SDimitry Andric   Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
33006c3fb27SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
33106c3fb27SDimitry Andric     // ...and left-shift the more significant int by this number, zero-filling
33206c3fb27SDimitry Andric     // the LSBs.
33306c3fb27SDimitry Andric     SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
33406c3fb27SDimitry Andric   } else {
33506c3fb27SDimitry Andric     // ...and right-shift the less significant int by this number, zero-filling
33606c3fb27SDimitry Andric     // the MSBs.
33706c3fb27SDimitry Andric     SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
33806c3fb27SDimitry Andric   }
33906c3fb27SDimitry Andric   // A simple binary addition of the shifted ints yields the final result.
34006c3fb27SDimitry Andric   IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
34106c3fb27SDimitry Andric 
34206c3fb27SDimitry Andric   FSHIntrinsic->setCalledFunction(FSHFunc);
34306c3fb27SDimitry Andric }
34406c3fb27SDimitry Andric 
34506c3fb27SDimitry Andric static void buildUMulWithOverflowFunc(Function *UMulFunc) {
34606c3fb27SDimitry Andric   // The function body is already created.
34706c3fb27SDimitry Andric   if (!UMulFunc->empty())
34806c3fb27SDimitry Andric     return;
34906c3fb27SDimitry Andric 
35006c3fb27SDimitry Andric   BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
35106c3fb27SDimitry Andric                                            "entry", UMulFunc);
35206c3fb27SDimitry Andric   IRBuilder<> IRB(EntryBB);
35306c3fb27SDimitry Andric   // Build the actual unsigned multiplication logic with the overflow
35406c3fb27SDimitry Andric   // indication. Do unsigned multiplication Mul = A * B. Then check
35506c3fb27SDimitry Andric   // if unsigned division Div = Mul / A is not equal to B. If so,
35606c3fb27SDimitry Andric   // then overflow has happened.
35706c3fb27SDimitry Andric   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
35806c3fb27SDimitry Andric   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
35906c3fb27SDimitry Andric   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
36006c3fb27SDimitry Andric 
36106c3fb27SDimitry Andric   // umul.with.overflow intrinsic return a structure, where the first element
36206c3fb27SDimitry Andric   // is the multiplication result, and the second is an overflow bit.
36306c3fb27SDimitry Andric   Type *StructTy = UMulFunc->getReturnType();
36406c3fb27SDimitry Andric   Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
36506c3fb27SDimitry Andric   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
36606c3fb27SDimitry Andric   IRB.CreateRet(Res);
36706c3fb27SDimitry Andric }
36806c3fb27SDimitry Andric 
3695f757f3fSDimitry Andric static void lowerExpectAssume(IntrinsicInst *II) {
3705f757f3fSDimitry Andric   // If we cannot use the SPV_KHR_expect_assume extension, then we need to
3715f757f3fSDimitry Andric   // ignore the intrinsic and move on. It should be removed later on by LLVM.
3725f757f3fSDimitry Andric   // Otherwise we should lower the intrinsic to the corresponding SPIR-V
3735f757f3fSDimitry Andric   // instruction.
3745f757f3fSDimitry Andric   // For @llvm.assume we have OpAssumeTrueKHR.
3755f757f3fSDimitry Andric   // For @llvm.expect we have OpExpectKHR.
3765f757f3fSDimitry Andric   //
3775f757f3fSDimitry Andric   // We need to lower this into a builtin and then the builtin into a SPIR-V
3785f757f3fSDimitry Andric   // instruction.
3795f757f3fSDimitry Andric   if (II->getIntrinsicID() == Intrinsic::assume) {
3805f757f3fSDimitry Andric     Function *F = Intrinsic::getDeclaration(
3815f757f3fSDimitry Andric         II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
3825f757f3fSDimitry Andric     II->setCalledFunction(F);
3835f757f3fSDimitry Andric   } else if (II->getIntrinsicID() == Intrinsic::expect) {
3845f757f3fSDimitry Andric     Function *F = Intrinsic::getDeclaration(
3855f757f3fSDimitry Andric         II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
3865f757f3fSDimitry Andric         {II->getOperand(0)->getType()});
3875f757f3fSDimitry Andric     II->setCalledFunction(F);
3885f757f3fSDimitry Andric   } else {
3895f757f3fSDimitry Andric     llvm_unreachable("Unknown intrinsic");
3905f757f3fSDimitry Andric   }
3915f757f3fSDimitry Andric 
3925f757f3fSDimitry Andric   return;
3935f757f3fSDimitry Andric }
3945f757f3fSDimitry Andric 
395*0fca6ea1SDimitry Andric static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
396*0fca6ea1SDimitry Andric                                      ArrayRef<unsigned> OpNos) {
397*0fca6ea1SDimitry Andric   Function *F = nullptr;
398*0fca6ea1SDimitry Andric   if (OpNos.empty()) {
399*0fca6ea1SDimitry Andric     F = Intrinsic::getDeclaration(II->getModule(), NewID);
400*0fca6ea1SDimitry Andric   } else {
401*0fca6ea1SDimitry Andric     SmallVector<Type *, 4> Tys;
402*0fca6ea1SDimitry Andric     for (unsigned OpNo : OpNos)
403*0fca6ea1SDimitry Andric       Tys.push_back(II->getOperand(OpNo)->getType());
404*0fca6ea1SDimitry Andric     F = Intrinsic::getDeclaration(II->getModule(), NewID, Tys);
405*0fca6ea1SDimitry Andric   }
406*0fca6ea1SDimitry Andric   II->setCalledFunction(F);
407*0fca6ea1SDimitry Andric   return true;
408*0fca6ea1SDimitry Andric }
409*0fca6ea1SDimitry Andric 
41006c3fb27SDimitry Andric static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
41106c3fb27SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
41206c3fb27SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
41306c3fb27SDimitry Andric   // function.
41406c3fb27SDimitry Andric   Module *M = UMulIntrinsic->getModule();
41506c3fb27SDimitry Andric   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
41606c3fb27SDimitry Andric   Type *FSHLRetTy = UMulFuncTy->getReturnType();
41706c3fb27SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
41806c3fb27SDimitry Andric   Function *UMulFunc =
41906c3fb27SDimitry Andric       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
42006c3fb27SDimitry Andric   buildUMulWithOverflowFunc(UMulFunc);
42106c3fb27SDimitry Andric   UMulIntrinsic->setCalledFunction(UMulFunc);
42206c3fb27SDimitry Andric }
42306c3fb27SDimitry Andric 
42406c3fb27SDimitry Andric // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
42506c3fb27SDimitry Andric // or calls to proper generated functions. Returns True if F was modified.
42606c3fb27SDimitry Andric bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
42706c3fb27SDimitry Andric   bool Changed = false;
42806c3fb27SDimitry Andric   for (BasicBlock &BB : *F) {
42906c3fb27SDimitry Andric     for (Instruction &I : BB) {
43006c3fb27SDimitry Andric       auto Call = dyn_cast<CallInst>(&I);
43106c3fb27SDimitry Andric       if (!Call)
43206c3fb27SDimitry Andric         continue;
43306c3fb27SDimitry Andric       Function *CF = Call->getCalledFunction();
43406c3fb27SDimitry Andric       if (!CF || !CF->isIntrinsic())
43506c3fb27SDimitry Andric         continue;
43606c3fb27SDimitry Andric       auto *II = cast<IntrinsicInst>(Call);
437*0fca6ea1SDimitry Andric       switch (II->getIntrinsicID()) {
438*0fca6ea1SDimitry Andric       case Intrinsic::memset:
439*0fca6ea1SDimitry Andric       case Intrinsic::bswap:
44006c3fb27SDimitry Andric         Changed |= lowerIntrinsicToFunction(II);
441*0fca6ea1SDimitry Andric         break;
442*0fca6ea1SDimitry Andric       case Intrinsic::fshl:
443*0fca6ea1SDimitry Andric       case Intrinsic::fshr:
44406c3fb27SDimitry Andric         lowerFunnelShifts(II);
44506c3fb27SDimitry Andric         Changed = true;
446*0fca6ea1SDimitry Andric         break;
447*0fca6ea1SDimitry Andric       case Intrinsic::umul_with_overflow:
44806c3fb27SDimitry Andric         lowerUMulWithOverflow(II);
44906c3fb27SDimitry Andric         Changed = true;
450*0fca6ea1SDimitry Andric         break;
451*0fca6ea1SDimitry Andric       case Intrinsic::assume:
452*0fca6ea1SDimitry Andric       case Intrinsic::expect: {
4535f757f3fSDimitry Andric         const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
4545f757f3fSDimitry Andric         if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))
4555f757f3fSDimitry Andric           lowerExpectAssume(II);
4565f757f3fSDimitry Andric         Changed = true;
457*0fca6ea1SDimitry Andric       } break;
458*0fca6ea1SDimitry Andric       case Intrinsic::lifetime_start:
459*0fca6ea1SDimitry Andric         Changed |= toSpvOverloadedIntrinsic(
460*0fca6ea1SDimitry Andric             II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});
461*0fca6ea1SDimitry Andric         break;
462*0fca6ea1SDimitry Andric       case Intrinsic::lifetime_end:
463*0fca6ea1SDimitry Andric         Changed |= toSpvOverloadedIntrinsic(
464*0fca6ea1SDimitry Andric             II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});
465*0fca6ea1SDimitry Andric         break;
466*0fca6ea1SDimitry Andric       case Intrinsic::ptr_annotation:
467*0fca6ea1SDimitry Andric         lowerPtrAnnotation(II);
468*0fca6ea1SDimitry Andric         Changed = true;
469*0fca6ea1SDimitry Andric         break;
47006c3fb27SDimitry Andric       }
47106c3fb27SDimitry Andric     }
47206c3fb27SDimitry Andric   }
47306c3fb27SDimitry Andric   return Changed;
47406c3fb27SDimitry Andric }
47506c3fb27SDimitry Andric 
47606c3fb27SDimitry Andric // Returns F if aggregate argument/return types are not present or cloned F
47706c3fb27SDimitry Andric // function with the types replaced by i32 types. The change in types is
47806c3fb27SDimitry Andric // noted in 'spv.cloned_funcs' metadata for later restoration.
47906c3fb27SDimitry Andric Function *
48006c3fb27SDimitry Andric SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
481fcaf7f86SDimitry Andric   IRBuilder<> B(F->getContext());
482fcaf7f86SDimitry Andric 
483fcaf7f86SDimitry Andric   bool IsRetAggr = F->getReturnType()->isAggregateType();
484fcaf7f86SDimitry Andric   bool HasAggrArg =
485fcaf7f86SDimitry Andric       std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
486fcaf7f86SDimitry Andric         return Arg.getType()->isAggregateType();
487fcaf7f86SDimitry Andric       });
488fcaf7f86SDimitry Andric   bool DoClone = IsRetAggr || HasAggrArg;
489fcaf7f86SDimitry Andric   if (!DoClone)
490fcaf7f86SDimitry Andric     return F;
491fcaf7f86SDimitry Andric   SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
492fcaf7f86SDimitry Andric   Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
493fcaf7f86SDimitry Andric   if (IsRetAggr)
494fcaf7f86SDimitry Andric     ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
495fcaf7f86SDimitry Andric   SmallVector<Type *, 4> ArgTypes;
496fcaf7f86SDimitry Andric   for (const auto &Arg : F->args()) {
497fcaf7f86SDimitry Andric     if (Arg.getType()->isAggregateType()) {
498fcaf7f86SDimitry Andric       ArgTypes.push_back(B.getInt32Ty());
499fcaf7f86SDimitry Andric       ChangedTypes.push_back(
500fcaf7f86SDimitry Andric           std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
501fcaf7f86SDimitry Andric     } else
502fcaf7f86SDimitry Andric       ArgTypes.push_back(Arg.getType());
503fcaf7f86SDimitry Andric   }
504fcaf7f86SDimitry Andric   FunctionType *NewFTy =
505fcaf7f86SDimitry Andric       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
506fcaf7f86SDimitry Andric   Function *NewF =
507fcaf7f86SDimitry Andric       Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
508fcaf7f86SDimitry Andric 
509fcaf7f86SDimitry Andric   ValueToValueMapTy VMap;
510fcaf7f86SDimitry Andric   auto NewFArgIt = NewF->arg_begin();
511fcaf7f86SDimitry Andric   for (auto &Arg : F->args()) {
512fcaf7f86SDimitry Andric     StringRef ArgName = Arg.getName();
513fcaf7f86SDimitry Andric     NewFArgIt->setName(ArgName);
514fcaf7f86SDimitry Andric     VMap[&Arg] = &(*NewFArgIt++);
515fcaf7f86SDimitry Andric   }
516fcaf7f86SDimitry Andric   SmallVector<ReturnInst *, 8> Returns;
517fcaf7f86SDimitry Andric 
518fcaf7f86SDimitry Andric   CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
519fcaf7f86SDimitry Andric                     Returns);
520fcaf7f86SDimitry Andric   NewF->takeName(F);
521fcaf7f86SDimitry Andric 
522fcaf7f86SDimitry Andric   NamedMDNode *FuncMD =
523fcaf7f86SDimitry Andric       F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
524fcaf7f86SDimitry Andric   SmallVector<Metadata *, 2> MDArgs;
525fcaf7f86SDimitry Andric   MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
526fcaf7f86SDimitry Andric   for (auto &ChangedTyP : ChangedTypes)
527fcaf7f86SDimitry Andric     MDArgs.push_back(MDNode::get(
528fcaf7f86SDimitry Andric         B.getContext(),
529fcaf7f86SDimitry Andric         {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
530fcaf7f86SDimitry Andric          ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
531fcaf7f86SDimitry Andric   MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
532fcaf7f86SDimitry Andric   FuncMD->addOperand(ThisFuncMD);
533fcaf7f86SDimitry Andric 
534fcaf7f86SDimitry Andric   for (auto *U : make_early_inc_range(F->users())) {
535fcaf7f86SDimitry Andric     if (auto *CI = dyn_cast<CallInst>(U))
536fcaf7f86SDimitry Andric       CI->mutateFunctionType(NewF->getFunctionType());
537fcaf7f86SDimitry Andric     U->replaceUsesOfWith(F, NewF);
538fcaf7f86SDimitry Andric   }
539*0fca6ea1SDimitry Andric 
540*0fca6ea1SDimitry Andric   // register the mutation
541*0fca6ea1SDimitry Andric   if (RetType != F->getReturnType())
542*0fca6ea1SDimitry Andric     TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
543*0fca6ea1SDimitry Andric         NewF, F->getReturnType());
544fcaf7f86SDimitry Andric   return NewF;
545fcaf7f86SDimitry Andric }
546fcaf7f86SDimitry Andric 
547fcaf7f86SDimitry Andric bool SPIRVPrepareFunctions::runOnModule(Module &M) {
54806c3fb27SDimitry Andric   bool Changed = false;
549fcaf7f86SDimitry Andric   for (Function &F : M)
55006c3fb27SDimitry Andric     Changed |= substituteIntrinsicCalls(&F);
551fcaf7f86SDimitry Andric 
552fcaf7f86SDimitry Andric   std::vector<Function *> FuncsWorklist;
553fcaf7f86SDimitry Andric   for (auto &F : M)
554fcaf7f86SDimitry Andric     FuncsWorklist.push_back(&F);
555fcaf7f86SDimitry Andric 
55606c3fb27SDimitry Andric   for (auto *F : FuncsWorklist) {
55706c3fb27SDimitry Andric     Function *NewF = removeAggregateTypesFromSignature(F);
558fcaf7f86SDimitry Andric 
55906c3fb27SDimitry Andric     if (NewF != F) {
56006c3fb27SDimitry Andric       F->eraseFromParent();
56106c3fb27SDimitry Andric       Changed = true;
562fcaf7f86SDimitry Andric     }
563fcaf7f86SDimitry Andric   }
564fcaf7f86SDimitry Andric   return Changed;
565fcaf7f86SDimitry Andric }
566fcaf7f86SDimitry Andric 
5675f757f3fSDimitry Andric ModulePass *
5685f757f3fSDimitry Andric llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
5695f757f3fSDimitry Andric   return new SPIRVPrepareFunctions(TM);
570fcaf7f86SDimitry Andric }
571