xref: /llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (revision fe7cb156064ff59dba7c0496db3b4da39fb1a663)
1b8e1544bSIlia Diachkov //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2b8e1544bSIlia Diachkov //
3b8e1544bSIlia Diachkov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b8e1544bSIlia Diachkov // See https://llvm.org/LICENSE.txt for license information.
5b8e1544bSIlia Diachkov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b8e1544bSIlia Diachkov //
7b8e1544bSIlia Diachkov //===----------------------------------------------------------------------===//
8b8e1544bSIlia Diachkov //
9b8e1544bSIlia Diachkov // This pass modifies function signatures containing aggregate arguments
1078ea3cb0SMichal Paszkowski // and/or return value before IRTranslator. Information about the original
1178ea3cb0SMichal Paszkowski // signatures is stored in metadata. It is used during call lowering to
1278ea3cb0SMichal Paszkowski // restore correct SPIR-V types of function arguments and return values.
// This pass also substitutes some LLVM intrinsic calls with calls to either
// SPIR-V intrinsics or newly generated functions (the latter as the Khronos
// LLVM/SPIR-V Translator does).
15b8e1544bSIlia Diachkov //
16b8e1544bSIlia Diachkov // NOTE: this pass is a module-level one due to the necessity to modify
17b8e1544bSIlia Diachkov // GVs/functions.
18b8e1544bSIlia Diachkov //
19b8e1544bSIlia Diachkov //===----------------------------------------------------------------------===//
20b8e1544bSIlia Diachkov 
21b8e1544bSIlia Diachkov #include "SPIRV.h"
228b732658SPaulo Matos #include "SPIRVSubtarget.h"
23b8e1544bSIlia Diachkov #include "SPIRVTargetMachine.h"
24b8e1544bSIlia Diachkov #include "SPIRVUtils.h"
25*fe7cb156SVyacheslav Levytskyy #include "llvm/ADT/StringExtras.h"
26f63adf3bSVyacheslav Levytskyy #include "llvm/Analysis/ValueTracking.h"
273544d200SIlia Diachkov #include "llvm/CodeGen/IntrinsicLowering.h"
28b8e1544bSIlia Diachkov #include "llvm/IR/IRBuilder.h"
29b8e1544bSIlia Diachkov #include "llvm/IR/IntrinsicInst.h"
3005640657SPaulo Matos #include "llvm/IR/Intrinsics.h"
3105640657SPaulo Matos #include "llvm/IR/IntrinsicsSPIRV.h"
32b8e1544bSIlia Diachkov #include "llvm/Transforms/Utils/Cloning.h"
33b8e1544bSIlia Diachkov #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
34f63adf3bSVyacheslav Levytskyy #include <regex>
35b8e1544bSIlia Diachkov 
36b8e1544bSIlia Diachkov using namespace llvm;
37b8e1544bSIlia Diachkov 
38b8e1544bSIlia Diachkov namespace llvm {
39b8e1544bSIlia Diachkov void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
40b8e1544bSIlia Diachkov }
41b8e1544bSIlia Diachkov 
42b8e1544bSIlia Diachkov namespace {
43b8e1544bSIlia Diachkov 
44b8e1544bSIlia Diachkov class SPIRVPrepareFunctions : public ModulePass {
458b732658SPaulo Matos   const SPIRVTargetMachine &TM;
4678ea3cb0SMichal Paszkowski   bool substituteIntrinsicCalls(Function *F);
4778ea3cb0SMichal Paszkowski   Function *removeAggregateTypesFromSignature(Function *F);
48b8e1544bSIlia Diachkov 
49b8e1544bSIlia Diachkov public:
50b8e1544bSIlia Diachkov   static char ID;
518b732658SPaulo Matos   SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
52b8e1544bSIlia Diachkov     initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
53b8e1544bSIlia Diachkov   }
54b8e1544bSIlia Diachkov 
55b8e1544bSIlia Diachkov   bool runOnModule(Module &M) override;
56b8e1544bSIlia Diachkov 
57b8e1544bSIlia Diachkov   StringRef getPassName() const override { return "SPIRV prepare functions"; }
58b8e1544bSIlia Diachkov 
59b8e1544bSIlia Diachkov   void getAnalysisUsage(AnalysisUsage &AU) const override {
60b8e1544bSIlia Diachkov     ModulePass::getAnalysisUsage(AU);
61b8e1544bSIlia Diachkov   }
62b8e1544bSIlia Diachkov };
63b8e1544bSIlia Diachkov 
64b8e1544bSIlia Diachkov } // namespace
65b8e1544bSIlia Diachkov 
66b8e1544bSIlia Diachkov char SPIRVPrepareFunctions::ID = 0;
67b8e1544bSIlia Diachkov 
68b8e1544bSIlia Diachkov INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
69b8e1544bSIlia Diachkov                 "SPIRV prepare functions", false, false)
70b8e1544bSIlia Diachkov 
7178ea3cb0SMichal Paszkowski std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
7278ea3cb0SMichal Paszkowski   Function *IntrinsicFunc = II->getCalledFunction();
7378ea3cb0SMichal Paszkowski   assert(IntrinsicFunc && "Missing function");
7478ea3cb0SMichal Paszkowski   std::string FuncName = IntrinsicFunc->getName().str();
7578ea3cb0SMichal Paszkowski   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
7678ea3cb0SMichal Paszkowski   FuncName = "spirv." + FuncName;
7778ea3cb0SMichal Paszkowski   return FuncName;
7878ea3cb0SMichal Paszkowski }
7978ea3cb0SMichal Paszkowski 
8078ea3cb0SMichal Paszkowski static Function *getOrCreateFunction(Module *M, Type *RetTy,
8178ea3cb0SMichal Paszkowski                                      ArrayRef<Type *> ArgTypes,
8278ea3cb0SMichal Paszkowski                                      StringRef Name) {
8378ea3cb0SMichal Paszkowski   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
8478ea3cb0SMichal Paszkowski   Function *F = M->getFunction(Name);
8578ea3cb0SMichal Paszkowski   if (F && F->getFunctionType() == FT)
8678ea3cb0SMichal Paszkowski     return F;
8778ea3cb0SMichal Paszkowski   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
8878ea3cb0SMichal Paszkowski   if (F)
8978ea3cb0SMichal Paszkowski     NewF->setDSOLocal(F->isDSOLocal());
9078ea3cb0SMichal Paszkowski   NewF->setCallingConv(CallingConv::SPIR_FUNC);
9178ea3cb0SMichal Paszkowski   return NewF;
9278ea3cb0SMichal Paszkowski }
9378ea3cb0SMichal Paszkowski 
9478ea3cb0SMichal Paszkowski static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
9578ea3cb0SMichal Paszkowski   // For @llvm.memset.* intrinsic cases with constant value and length arguments
9678ea3cb0SMichal Paszkowski   // are emulated via "storing" a constant array to the destination. For other
9778ea3cb0SMichal Paszkowski   // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
9878ea3cb0SMichal Paszkowski   // intrinsic to a loop via expandMemSetAsLoop().
9978ea3cb0SMichal Paszkowski   if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
10078ea3cb0SMichal Paszkowski     if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
10178ea3cb0SMichal Paszkowski       return false; // It is handled later using OpCopyMemorySized.
10278ea3cb0SMichal Paszkowski 
10378ea3cb0SMichal Paszkowski   Module *M = Intrinsic->getModule();
10478ea3cb0SMichal Paszkowski   std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
10578ea3cb0SMichal Paszkowski   if (Intrinsic->isVolatile())
10678ea3cb0SMichal Paszkowski     FuncName += ".volatile";
10778ea3cb0SMichal Paszkowski   // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
10878ea3cb0SMichal Paszkowski   Function *F = M->getFunction(FuncName);
10978ea3cb0SMichal Paszkowski   if (F) {
11078ea3cb0SMichal Paszkowski     Intrinsic->setCalledFunction(F);
11178ea3cb0SMichal Paszkowski     return true;
11278ea3cb0SMichal Paszkowski   }
11378ea3cb0SMichal Paszkowski   // TODO copy arguments attributes: nocapture writeonly.
11478ea3cb0SMichal Paszkowski   FunctionCallee FC =
11578ea3cb0SMichal Paszkowski       M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
11678ea3cb0SMichal Paszkowski   auto IntrinsicID = Intrinsic->getIntrinsicID();
11778ea3cb0SMichal Paszkowski   Intrinsic->setCalledFunction(FC);
11878ea3cb0SMichal Paszkowski 
11978ea3cb0SMichal Paszkowski   F = dyn_cast<Function>(FC.getCallee());
12078ea3cb0SMichal Paszkowski   assert(F && "Callee must be a function");
12178ea3cb0SMichal Paszkowski 
12278ea3cb0SMichal Paszkowski   switch (IntrinsicID) {
12378ea3cb0SMichal Paszkowski   case Intrinsic::memset: {
12478ea3cb0SMichal Paszkowski     auto *MSI = static_cast<MemSetInst *>(Intrinsic);
12578ea3cb0SMichal Paszkowski     Argument *Dest = F->getArg(0);
12678ea3cb0SMichal Paszkowski     Argument *Val = F->getArg(1);
12778ea3cb0SMichal Paszkowski     Argument *Len = F->getArg(2);
12878ea3cb0SMichal Paszkowski     Argument *IsVolatile = F->getArg(3);
12978ea3cb0SMichal Paszkowski     Dest->setName("dest");
13078ea3cb0SMichal Paszkowski     Val->setName("val");
13178ea3cb0SMichal Paszkowski     Len->setName("len");
13278ea3cb0SMichal Paszkowski     IsVolatile->setName("isvolatile");
13378ea3cb0SMichal Paszkowski     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
13478ea3cb0SMichal Paszkowski     IRBuilder<> IRB(EntryBB);
13578ea3cb0SMichal Paszkowski     auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
13678ea3cb0SMichal Paszkowski                                     MSI->isVolatile());
13778ea3cb0SMichal Paszkowski     IRB.CreateRetVoid();
13878ea3cb0SMichal Paszkowski     expandMemSetAsLoop(cast<MemSetInst>(MemSet));
13978ea3cb0SMichal Paszkowski     MemSet->eraseFromParent();
14078ea3cb0SMichal Paszkowski     break;
14178ea3cb0SMichal Paszkowski   }
14278ea3cb0SMichal Paszkowski   case Intrinsic::bswap: {
14378ea3cb0SMichal Paszkowski     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
14478ea3cb0SMichal Paszkowski     IRBuilder<> IRB(EntryBB);
14578ea3cb0SMichal Paszkowski     auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
14678ea3cb0SMichal Paszkowski                                       F->getArg(0));
14778ea3cb0SMichal Paszkowski     IRB.CreateRet(BSwap);
14878ea3cb0SMichal Paszkowski     IntrinsicLowering IL(M->getDataLayout());
14978ea3cb0SMichal Paszkowski     IL.LowerIntrinsicCall(BSwap);
15078ea3cb0SMichal Paszkowski     break;
15178ea3cb0SMichal Paszkowski   }
15278ea3cb0SMichal Paszkowski   default:
15378ea3cb0SMichal Paszkowski     break;
15478ea3cb0SMichal Paszkowski   }
15578ea3cb0SMichal Paszkowski   return true;
15678ea3cb0SMichal Paszkowski }
15778ea3cb0SMichal Paszkowski 
158f63adf3bSVyacheslav Levytskyy static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) {
159f63adf3bSVyacheslav Levytskyy   if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(AnnoVal))
160f63adf3bSVyacheslav Levytskyy     AnnoVal = Ref->getOperand(0);
161f63adf3bSVyacheslav Levytskyy   if (auto *Ref = dyn_cast_or_null<BitCastInst>(OptAnnoVal))
162f63adf3bSVyacheslav Levytskyy     OptAnnoVal = Ref->getOperand(0);
163f63adf3bSVyacheslav Levytskyy 
164f63adf3bSVyacheslav Levytskyy   std::string Anno;
165f63adf3bSVyacheslav Levytskyy   if (auto *C = dyn_cast_or_null<Constant>(AnnoVal)) {
166f63adf3bSVyacheslav Levytskyy     StringRef Str;
167f63adf3bSVyacheslav Levytskyy     if (getConstantStringInfo(C, Str))
168f63adf3bSVyacheslav Levytskyy       Anno = Str;
169f63adf3bSVyacheslav Levytskyy   }
170f63adf3bSVyacheslav Levytskyy   // handle optional annotation parameter in a way that Khronos Translator do
171f63adf3bSVyacheslav Levytskyy   // (collect integers wrapped in a struct)
172f63adf3bSVyacheslav Levytskyy   if (auto *C = dyn_cast_or_null<Constant>(OptAnnoVal);
173f63adf3bSVyacheslav Levytskyy       C && C->getNumOperands()) {
174f63adf3bSVyacheslav Levytskyy     Value *MaybeStruct = C->getOperand(0);
175f63adf3bSVyacheslav Levytskyy     if (auto *Struct = dyn_cast<ConstantStruct>(MaybeStruct)) {
176f63adf3bSVyacheslav Levytskyy       for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) {
177f63adf3bSVyacheslav Levytskyy         if (auto *CInt = dyn_cast<ConstantInt>(Struct->getOperand(I)))
178f63adf3bSVyacheslav Levytskyy           Anno += (I == 0 ? ": " : ", ") +
179f63adf3bSVyacheslav Levytskyy                   std::to_string(CInt->getType()->getIntegerBitWidth() == 1
180f63adf3bSVyacheslav Levytskyy                                      ? CInt->getZExtValue()
181f63adf3bSVyacheslav Levytskyy                                      : CInt->getSExtValue());
182f63adf3bSVyacheslav Levytskyy       }
183f63adf3bSVyacheslav Levytskyy     } else if (auto *Struct = dyn_cast<ConstantAggregateZero>(MaybeStruct)) {
184f63adf3bSVyacheslav Levytskyy       // { i32 i32 ... } zeroinitializer
185f63adf3bSVyacheslav Levytskyy       for (unsigned I = 0, E = Struct->getType()->getStructNumElements();
186f63adf3bSVyacheslav Levytskyy            I != E; ++I)
187f63adf3bSVyacheslav Levytskyy         Anno += I == 0 ? ": 0" : ", 0";
188f63adf3bSVyacheslav Levytskyy     }
189f63adf3bSVyacheslav Levytskyy   }
190f63adf3bSVyacheslav Levytskyy   return Anno;
191f63adf3bSVyacheslav Levytskyy }
192f63adf3bSVyacheslav Levytskyy 
193f63adf3bSVyacheslav Levytskyy static SmallVector<Metadata *> parseAnnotation(Value *I,
194f63adf3bSVyacheslav Levytskyy                                                const std::string &Anno,
195f63adf3bSVyacheslav Levytskyy                                                LLVMContext &Ctx,
196f63adf3bSVyacheslav Levytskyy                                                Type *Int32Ty) {
197f63adf3bSVyacheslav Levytskyy   // Try to parse the annotation string according to the following rules:
198f63adf3bSVyacheslav Levytskyy   // annotation := ({kind} | {kind:value,value,...})+
199f63adf3bSVyacheslav Levytskyy   // kind := number
200f63adf3bSVyacheslav Levytskyy   // value := number | string
201f63adf3bSVyacheslav Levytskyy   static const std::regex R(
202f63adf3bSVyacheslav Levytskyy       "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");
203f63adf3bSVyacheslav Levytskyy   SmallVector<Metadata *> MDs;
204f63adf3bSVyacheslav Levytskyy   int Pos = 0;
205f63adf3bSVyacheslav Levytskyy   for (std::sregex_iterator
206f63adf3bSVyacheslav Levytskyy            It = std::sregex_iterator(Anno.begin(), Anno.end(), R),
207f63adf3bSVyacheslav Levytskyy            ItEnd = std::sregex_iterator();
208f63adf3bSVyacheslav Levytskyy        It != ItEnd; ++It) {
209f63adf3bSVyacheslav Levytskyy     if (It->position() != Pos)
210f63adf3bSVyacheslav Levytskyy       return SmallVector<Metadata *>{};
211f63adf3bSVyacheslav Levytskyy     Pos = It->position() + It->length();
212f63adf3bSVyacheslav Levytskyy     std::smatch Match = *It;
213f63adf3bSVyacheslav Levytskyy     SmallVector<Metadata *> MDsItem;
214f63adf3bSVyacheslav Levytskyy     for (std::size_t i = 1; i < Match.size(); ++i) {
215f63adf3bSVyacheslav Levytskyy       std::ssub_match SMatch = Match[i];
216f63adf3bSVyacheslav Levytskyy       std::string Item = SMatch.str();
217f63adf3bSVyacheslav Levytskyy       if (Item.length() == 0)
218f63adf3bSVyacheslav Levytskyy         break;
219f63adf3bSVyacheslav Levytskyy       if (Item[0] == '"') {
220f63adf3bSVyacheslav Levytskyy         Item = Item.substr(1, Item.length() - 2);
221f63adf3bSVyacheslav Levytskyy         // Acceptable format of the string snippet is:
222f63adf3bSVyacheslav Levytskyy         static const std::regex RStr("^(\\d+)(?:,(\\d+))*$");
223f63adf3bSVyacheslav Levytskyy         if (std::smatch MatchStr; std::regex_match(Item, MatchStr, RStr)) {
224f63adf3bSVyacheslav Levytskyy           for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)
225f63adf3bSVyacheslav Levytskyy             if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length())
226f63adf3bSVyacheslav Levytskyy               MDsItem.push_back(ConstantAsMetadata::get(
227f63adf3bSVyacheslav Levytskyy                   ConstantInt::get(Int32Ty, std::stoi(SubStr))));
228f63adf3bSVyacheslav Levytskyy         } else {
229f63adf3bSVyacheslav Levytskyy           MDsItem.push_back(MDString::get(Ctx, Item));
230f63adf3bSVyacheslav Levytskyy         }
231*fe7cb156SVyacheslav Levytskyy       } else if (int32_t Num; llvm::to_integer(StringRef(Item), Num, 10)) {
232f63adf3bSVyacheslav Levytskyy         MDsItem.push_back(
233f63adf3bSVyacheslav Levytskyy             ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Num)));
234f63adf3bSVyacheslav Levytskyy       } else {
235f63adf3bSVyacheslav Levytskyy         MDsItem.push_back(MDString::get(Ctx, Item));
236f63adf3bSVyacheslav Levytskyy       }
237f63adf3bSVyacheslav Levytskyy     }
238f63adf3bSVyacheslav Levytskyy     if (MDsItem.size() == 0)
239f63adf3bSVyacheslav Levytskyy       return SmallVector<Metadata *>{};
240f63adf3bSVyacheslav Levytskyy     MDs.push_back(MDNode::get(Ctx, MDsItem));
241f63adf3bSVyacheslav Levytskyy   }
242f63adf3bSVyacheslav Levytskyy   return Pos == static_cast<int>(Anno.length()) ? MDs
243f63adf3bSVyacheslav Levytskyy                                                 : SmallVector<Metadata *>{};
244f63adf3bSVyacheslav Levytskyy }
245f63adf3bSVyacheslav Levytskyy 
246f63adf3bSVyacheslav Levytskyy static void lowerPtrAnnotation(IntrinsicInst *II) {
247f63adf3bSVyacheslav Levytskyy   LLVMContext &Ctx = II->getContext();
248f63adf3bSVyacheslav Levytskyy   Type *Int32Ty = Type::getInt32Ty(Ctx);
249f63adf3bSVyacheslav Levytskyy 
250f63adf3bSVyacheslav Levytskyy   // Retrieve an annotation string from arguments.
251f63adf3bSVyacheslav Levytskyy   Value *PtrArg = nullptr;
252f63adf3bSVyacheslav Levytskyy   if (auto *BI = dyn_cast<BitCastInst>(II->getArgOperand(0)))
253f63adf3bSVyacheslav Levytskyy     PtrArg = BI->getOperand(0);
254f63adf3bSVyacheslav Levytskyy   else
255f63adf3bSVyacheslav Levytskyy     PtrArg = II->getOperand(0);
256f63adf3bSVyacheslav Levytskyy   std::string Anno =
257f63adf3bSVyacheslav Levytskyy       getAnnotation(II->getArgOperand(1),
258f63adf3bSVyacheslav Levytskyy                     4 < II->arg_size() ? II->getArgOperand(4) : nullptr);
259f63adf3bSVyacheslav Levytskyy 
260f63adf3bSVyacheslav Levytskyy   // Parse the annotation.
261f63adf3bSVyacheslav Levytskyy   SmallVector<Metadata *> MDs = parseAnnotation(II, Anno, Ctx, Int32Ty);
262f63adf3bSVyacheslav Levytskyy 
263f63adf3bSVyacheslav Levytskyy   // If the annotation string is not parsed successfully we don't know the
264f63adf3bSVyacheslav Levytskyy   // format used and output it as a general UserSemantic decoration.
265f63adf3bSVyacheslav Levytskyy   // Otherwise MDs is a Metadata tuple (a decoration list) in the format
266f63adf3bSVyacheslav Levytskyy   // expected by `spirv.Decorations`.
267f63adf3bSVyacheslav Levytskyy   if (MDs.size() == 0) {
268f63adf3bSVyacheslav Levytskyy     auto UserSemantic = ConstantAsMetadata::get(ConstantInt::get(
269f63adf3bSVyacheslav Levytskyy         Int32Ty, static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));
270f63adf3bSVyacheslav Levytskyy     MDs.push_back(MDNode::get(Ctx, {UserSemantic, MDString::get(Ctx, Anno)}));
271f63adf3bSVyacheslav Levytskyy   }
272f63adf3bSVyacheslav Levytskyy 
273f63adf3bSVyacheslav Levytskyy   // Build the internal intrinsic function.
274f63adf3bSVyacheslav Levytskyy   IRBuilder<> IRB(II->getParent());
275f63adf3bSVyacheslav Levytskyy   IRB.SetInsertPoint(II);
276f63adf3bSVyacheslav Levytskyy   IRB.CreateIntrinsic(
277f63adf3bSVyacheslav Levytskyy       Intrinsic::spv_assign_decoration, {PtrArg->getType()},
278f63adf3bSVyacheslav Levytskyy       {PtrArg, MetadataAsValue::get(Ctx, MDNode::get(Ctx, MDs))});
279f63adf3bSVyacheslav Levytskyy   II->replaceAllUsesWith(II->getOperand(0));
280f63adf3bSVyacheslav Levytskyy }
281f63adf3bSVyacheslav Levytskyy 
28278ea3cb0SMichal Paszkowski static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
28378ea3cb0SMichal Paszkowski   // Get a separate function - otherwise, we'd have to rework the CFG of the
28478ea3cb0SMichal Paszkowski   // current one. Then simply replace the intrinsic uses with a call to the new
28578ea3cb0SMichal Paszkowski   // function.
28678ea3cb0SMichal Paszkowski   // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
28778ea3cb0SMichal Paszkowski   Module *M = FSHIntrinsic->getModule();
28878ea3cb0SMichal Paszkowski   FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
28978ea3cb0SMichal Paszkowski   Type *FSHRetTy = FSHFuncTy->getReturnType();
29078ea3cb0SMichal Paszkowski   const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
29178ea3cb0SMichal Paszkowski   Function *FSHFunc =
29278ea3cb0SMichal Paszkowski       getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
29378ea3cb0SMichal Paszkowski 
29478ea3cb0SMichal Paszkowski   if (!FSHFunc->empty()) {
29578ea3cb0SMichal Paszkowski     FSHIntrinsic->setCalledFunction(FSHFunc);
29678ea3cb0SMichal Paszkowski     return;
29778ea3cb0SMichal Paszkowski   }
29878ea3cb0SMichal Paszkowski   BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
29978ea3cb0SMichal Paszkowski   IRBuilder<> IRB(RotateBB);
30078ea3cb0SMichal Paszkowski   Type *Ty = FSHFunc->getReturnType();
30178ea3cb0SMichal Paszkowski   // Build the actual funnel shift rotate logic.
30278ea3cb0SMichal Paszkowski   // In the comments, "int" is used interchangeably with "vector of int
30378ea3cb0SMichal Paszkowski   // elements".
30478ea3cb0SMichal Paszkowski   FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
30578ea3cb0SMichal Paszkowski   Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
30678ea3cb0SMichal Paszkowski   unsigned BitWidth = IntTy->getIntegerBitWidth();
30778ea3cb0SMichal Paszkowski   ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
30878ea3cb0SMichal Paszkowski   Value *BitWidthForInsts =
30978ea3cb0SMichal Paszkowski       VectorTy
31078ea3cb0SMichal Paszkowski           ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
31178ea3cb0SMichal Paszkowski           : BitWidthConstant;
31278ea3cb0SMichal Paszkowski   Value *RotateModVal =
31378ea3cb0SMichal Paszkowski       IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
31478ea3cb0SMichal Paszkowski   Value *FirstShift = nullptr, *SecShift = nullptr;
31578ea3cb0SMichal Paszkowski   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
31678ea3cb0SMichal Paszkowski     // Shift the less significant number right, the "rotate" number of bits
31778ea3cb0SMichal Paszkowski     // will be 0-filled on the left as a result of this regular shift.
31878ea3cb0SMichal Paszkowski     FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
31978ea3cb0SMichal Paszkowski   } else {
32078ea3cb0SMichal Paszkowski     // Shift the more significant number left, the "rotate" number of bits
32178ea3cb0SMichal Paszkowski     // will be 0-filled on the right as a result of this regular shift.
32278ea3cb0SMichal Paszkowski     FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
32378ea3cb0SMichal Paszkowski   }
32478ea3cb0SMichal Paszkowski   // We want the "rotate" number of the more significant int's LSBs (MSBs) to
32578ea3cb0SMichal Paszkowski   // occupy the leftmost (rightmost) "0 space" left by the previous operation.
32678ea3cb0SMichal Paszkowski   // Therefore, subtract the "rotate" number from the integer bitsize...
32778ea3cb0SMichal Paszkowski   Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
32878ea3cb0SMichal Paszkowski   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
32978ea3cb0SMichal Paszkowski     // ...and left-shift the more significant int by this number, zero-filling
33078ea3cb0SMichal Paszkowski     // the LSBs.
33178ea3cb0SMichal Paszkowski     SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
33278ea3cb0SMichal Paszkowski   } else {
33378ea3cb0SMichal Paszkowski     // ...and right-shift the less significant int by this number, zero-filling
33478ea3cb0SMichal Paszkowski     // the MSBs.
33578ea3cb0SMichal Paszkowski     SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
33678ea3cb0SMichal Paszkowski   }
33778ea3cb0SMichal Paszkowski   // A simple binary addition of the shifted ints yields the final result.
33878ea3cb0SMichal Paszkowski   IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
33978ea3cb0SMichal Paszkowski 
34078ea3cb0SMichal Paszkowski   FSHIntrinsic->setCalledFunction(FSHFunc);
34178ea3cb0SMichal Paszkowski }
34278ea3cb0SMichal Paszkowski 
34305640657SPaulo Matos static void lowerExpectAssume(IntrinsicInst *II) {
34405640657SPaulo Matos   // If we cannot use the SPV_KHR_expect_assume extension, then we need to
34505640657SPaulo Matos   // ignore the intrinsic and move on. It should be removed later on by LLVM.
34605640657SPaulo Matos   // Otherwise we should lower the intrinsic to the corresponding SPIR-V
34705640657SPaulo Matos   // instruction.
34805640657SPaulo Matos   // For @llvm.assume we have OpAssumeTrueKHR.
34905640657SPaulo Matos   // For @llvm.expect we have OpExpectKHR.
35005640657SPaulo Matos   //
35105640657SPaulo Matos   // We need to lower this into a builtin and then the builtin into a SPIR-V
35205640657SPaulo Matos   // instruction.
35305640657SPaulo Matos   if (II->getIntrinsicID() == Intrinsic::assume) {
354fa789dffSRahul Joshi     Function *F = Intrinsic::getOrInsertDeclaration(
35505640657SPaulo Matos         II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
35605640657SPaulo Matos     II->setCalledFunction(F);
35705640657SPaulo Matos   } else if (II->getIntrinsicID() == Intrinsic::expect) {
358fa789dffSRahul Joshi     Function *F = Intrinsic::getOrInsertDeclaration(
35905640657SPaulo Matos         II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
36005640657SPaulo Matos         {II->getOperand(0)->getType()});
36105640657SPaulo Matos     II->setCalledFunction(F);
36205640657SPaulo Matos   } else {
36305640657SPaulo Matos     llvm_unreachable("Unknown intrinsic");
36405640657SPaulo Matos   }
36505640657SPaulo Matos 
36605640657SPaulo Matos   return;
36705640657SPaulo Matos }
36805640657SPaulo Matos 
36959f34e8cSVyacheslav Levytskyy static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
37059f34e8cSVyacheslav Levytskyy                                      ArrayRef<unsigned> OpNos) {
37159f34e8cSVyacheslav Levytskyy   Function *F = nullptr;
37259f34e8cSVyacheslav Levytskyy   if (OpNos.empty()) {
373fa789dffSRahul Joshi     F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID);
37459f34e8cSVyacheslav Levytskyy   } else {
37559f34e8cSVyacheslav Levytskyy     SmallVector<Type *, 4> Tys;
37659f34e8cSVyacheslav Levytskyy     for (unsigned OpNo : OpNos)
37759f34e8cSVyacheslav Levytskyy       Tys.push_back(II->getOperand(OpNo)->getType());
378fa789dffSRahul Joshi     F = Intrinsic::getOrInsertDeclaration(II->getModule(), NewID, Tys);
37959f34e8cSVyacheslav Levytskyy   }
38059f34e8cSVyacheslav Levytskyy   II->setCalledFunction(F);
38159f34e8cSVyacheslav Levytskyy   return true;
38259f34e8cSVyacheslav Levytskyy }
38359f34e8cSVyacheslav Levytskyy 
38478ea3cb0SMichal Paszkowski // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
38578ea3cb0SMichal Paszkowski // or calls to proper generated functions. Returns True if F was modified.
38678ea3cb0SMichal Paszkowski bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
38778ea3cb0SMichal Paszkowski   bool Changed = false;
38878ea3cb0SMichal Paszkowski   for (BasicBlock &BB : *F) {
38978ea3cb0SMichal Paszkowski     for (Instruction &I : BB) {
39078ea3cb0SMichal Paszkowski       auto Call = dyn_cast<CallInst>(&I);
39178ea3cb0SMichal Paszkowski       if (!Call)
39278ea3cb0SMichal Paszkowski         continue;
39378ea3cb0SMichal Paszkowski       Function *CF = Call->getCalledFunction();
39478ea3cb0SMichal Paszkowski       if (!CF || !CF->isIntrinsic())
39578ea3cb0SMichal Paszkowski         continue;
39678ea3cb0SMichal Paszkowski       auto *II = cast<IntrinsicInst>(Call);
39759f34e8cSVyacheslav Levytskyy       switch (II->getIntrinsicID()) {
39859f34e8cSVyacheslav Levytskyy       case Intrinsic::memset:
39959f34e8cSVyacheslav Levytskyy       case Intrinsic::bswap:
40078ea3cb0SMichal Paszkowski         Changed |= lowerIntrinsicToFunction(II);
40159f34e8cSVyacheslav Levytskyy         break;
40259f34e8cSVyacheslav Levytskyy       case Intrinsic::fshl:
40359f34e8cSVyacheslav Levytskyy       case Intrinsic::fshr:
40478ea3cb0SMichal Paszkowski         lowerFunnelShifts(II);
40578ea3cb0SMichal Paszkowski         Changed = true;
40659f34e8cSVyacheslav Levytskyy         break;
40759f34e8cSVyacheslav Levytskyy       case Intrinsic::assume:
40859f34e8cSVyacheslav Levytskyy       case Intrinsic::expect: {
4098b732658SPaulo Matos         const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
4108b732658SPaulo Matos         if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))
41105640657SPaulo Matos           lowerExpectAssume(II);
41205640657SPaulo Matos         Changed = true;
41359f34e8cSVyacheslav Levytskyy       } break;
41459f34e8cSVyacheslav Levytskyy       case Intrinsic::lifetime_start:
41559f34e8cSVyacheslav Levytskyy         Changed |= toSpvOverloadedIntrinsic(
41659f34e8cSVyacheslav Levytskyy             II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});
41759f34e8cSVyacheslav Levytskyy         break;
41859f34e8cSVyacheslav Levytskyy       case Intrinsic::lifetime_end:
41959f34e8cSVyacheslav Levytskyy         Changed |= toSpvOverloadedIntrinsic(
42059f34e8cSVyacheslav Levytskyy             II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});
42159f34e8cSVyacheslav Levytskyy         break;
422f63adf3bSVyacheslav Levytskyy       case Intrinsic::ptr_annotation:
423f63adf3bSVyacheslav Levytskyy         lowerPtrAnnotation(II);
424f63adf3bSVyacheslav Levytskyy         Changed = true;
425f63adf3bSVyacheslav Levytskyy         break;
42678ea3cb0SMichal Paszkowski       }
42778ea3cb0SMichal Paszkowski     }
42878ea3cb0SMichal Paszkowski   }
42978ea3cb0SMichal Paszkowski   return Changed;
43078ea3cb0SMichal Paszkowski }
43178ea3cb0SMichal Paszkowski 
43278ea3cb0SMichal Paszkowski // Returns F if aggregate argument/return types are not present or cloned F
43378ea3cb0SMichal Paszkowski // function with the types replaced by i32 types. The change in types is
43478ea3cb0SMichal Paszkowski // noted in 'spv.cloned_funcs' metadata for later restoration.
43578ea3cb0SMichal Paszkowski Function *
43678ea3cb0SMichal Paszkowski SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
437a059b299SVyacheslav Levytskyy   bool IsRetAggr = F->getReturnType()->isAggregateType();
438a059b299SVyacheslav Levytskyy   // Allow intrinsics with aggregate return type to reach GlobalISel
439a059b299SVyacheslav Levytskyy   if (F->isIntrinsic() && IsRetAggr)
440a059b299SVyacheslav Levytskyy     return F;
441a059b299SVyacheslav Levytskyy 
442b8e1544bSIlia Diachkov   IRBuilder<> B(F->getContext());
443b8e1544bSIlia Diachkov 
444b8e1544bSIlia Diachkov   bool HasAggrArg =
445b8e1544bSIlia Diachkov       std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
446b8e1544bSIlia Diachkov         return Arg.getType()->isAggregateType();
447b8e1544bSIlia Diachkov       });
448b8e1544bSIlia Diachkov   bool DoClone = IsRetAggr || HasAggrArg;
449b8e1544bSIlia Diachkov   if (!DoClone)
450b8e1544bSIlia Diachkov     return F;
451b8e1544bSIlia Diachkov   SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
452b8e1544bSIlia Diachkov   Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
453b8e1544bSIlia Diachkov   if (IsRetAggr)
454b8e1544bSIlia Diachkov     ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
455b8e1544bSIlia Diachkov   SmallVector<Type *, 4> ArgTypes;
456b8e1544bSIlia Diachkov   for (const auto &Arg : F->args()) {
457b8e1544bSIlia Diachkov     if (Arg.getType()->isAggregateType()) {
458b8e1544bSIlia Diachkov       ArgTypes.push_back(B.getInt32Ty());
459b8e1544bSIlia Diachkov       ChangedTypes.push_back(
460b8e1544bSIlia Diachkov           std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
461b8e1544bSIlia Diachkov     } else
462b8e1544bSIlia Diachkov       ArgTypes.push_back(Arg.getType());
463b8e1544bSIlia Diachkov   }
464b8e1544bSIlia Diachkov   FunctionType *NewFTy =
465b8e1544bSIlia Diachkov       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
466b8e1544bSIlia Diachkov   Function *NewF =
467b8e1544bSIlia Diachkov       Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
468b8e1544bSIlia Diachkov 
469b8e1544bSIlia Diachkov   ValueToValueMapTy VMap;
470b8e1544bSIlia Diachkov   auto NewFArgIt = NewF->arg_begin();
471b8e1544bSIlia Diachkov   for (auto &Arg : F->args()) {
472b8e1544bSIlia Diachkov     StringRef ArgName = Arg.getName();
473b8e1544bSIlia Diachkov     NewFArgIt->setName(ArgName);
474b8e1544bSIlia Diachkov     VMap[&Arg] = &(*NewFArgIt++);
475b8e1544bSIlia Diachkov   }
476b8e1544bSIlia Diachkov   SmallVector<ReturnInst *, 8> Returns;
477b8e1544bSIlia Diachkov 
478b8e1544bSIlia Diachkov   CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
479b8e1544bSIlia Diachkov                     Returns);
480b8e1544bSIlia Diachkov   NewF->takeName(F);
481b8e1544bSIlia Diachkov 
482b8e1544bSIlia Diachkov   NamedMDNode *FuncMD =
483b8e1544bSIlia Diachkov       F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
484b8e1544bSIlia Diachkov   SmallVector<Metadata *, 2> MDArgs;
485b8e1544bSIlia Diachkov   MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
486b8e1544bSIlia Diachkov   for (auto &ChangedTyP : ChangedTypes)
487b8e1544bSIlia Diachkov     MDArgs.push_back(MDNode::get(
488b8e1544bSIlia Diachkov         B.getContext(),
489b8e1544bSIlia Diachkov         {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
490b8e1544bSIlia Diachkov          ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
491b8e1544bSIlia Diachkov   MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
492b8e1544bSIlia Diachkov   FuncMD->addOperand(ThisFuncMD);
493b8e1544bSIlia Diachkov 
494b8e1544bSIlia Diachkov   for (auto *U : make_early_inc_range(F->users())) {
495b8e1544bSIlia Diachkov     if (auto *CI = dyn_cast<CallInst>(U))
496b8e1544bSIlia Diachkov       CI->mutateFunctionType(NewF->getFunctionType());
497b8e1544bSIlia Diachkov     U->replaceUsesOfWith(F, NewF);
498b8e1544bSIlia Diachkov   }
499dbd00a59SVyacheslav Levytskyy 
500dbd00a59SVyacheslav Levytskyy   // register the mutation
501dbd00a59SVyacheslav Levytskyy   if (RetType != F->getReturnType())
502dbd00a59SVyacheslav Levytskyy     TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
503dbd00a59SVyacheslav Levytskyy         NewF, F->getReturnType());
504b8e1544bSIlia Diachkov   return NewF;
505b8e1544bSIlia Diachkov }
506b8e1544bSIlia Diachkov 
507b8e1544bSIlia Diachkov bool SPIRVPrepareFunctions::runOnModule(Module &M) {
50878ea3cb0SMichal Paszkowski   bool Changed = false;
5091ed65febSNathan Gauër   for (Function &F : M) {
51078ea3cb0SMichal Paszkowski     Changed |= substituteIntrinsicCalls(&F);
5111ed65febSNathan Gauër     Changed |= sortBlocks(F);
5121ed65febSNathan Gauër   }
513b8e1544bSIlia Diachkov 
514b8e1544bSIlia Diachkov   std::vector<Function *> FuncsWorklist;
515b8e1544bSIlia Diachkov   for (auto &F : M)
516b8e1544bSIlia Diachkov     FuncsWorklist.push_back(&F);
517b8e1544bSIlia Diachkov 
51878ea3cb0SMichal Paszkowski   for (auto *F : FuncsWorklist) {
51978ea3cb0SMichal Paszkowski     Function *NewF = removeAggregateTypesFromSignature(F);
520b8e1544bSIlia Diachkov 
52178ea3cb0SMichal Paszkowski     if (NewF != F) {
52278ea3cb0SMichal Paszkowski       F->eraseFromParent();
52378ea3cb0SMichal Paszkowski       Changed = true;
524b8e1544bSIlia Diachkov     }
525b8e1544bSIlia Diachkov   }
526b8e1544bSIlia Diachkov   return Changed;
527b8e1544bSIlia Diachkov }
528b8e1544bSIlia Diachkov 
5298b732658SPaulo Matos ModulePass *
5308b732658SPaulo Matos llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
5318b732658SPaulo Matos   return new SPIRVPrepareFunctions(TM);
532b8e1544bSIlia Diachkov }
533