//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies function signatures containing aggregate arguments
// and/or return value before IRTranslator. Information about the original
// signatures is stored in metadata. It is used during call lowering to
// restore correct SPIR-V types of function arguments and return values.
// This pass also substitutes some LLVM intrinsic calls with calls to newly
// generated functions (as the Khronos LLVM/SPIR-V Translator does) or with
// calls to SPIR-V intrinsics.
//
// NOTE: this pass is a module-level one due to the necessity to modify
// GVs/functions.
18fcaf7f86SDimitry Andric // 19fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===// 20fcaf7f86SDimitry Andric 21fcaf7f86SDimitry Andric #include "SPIRV.h" 22*5f757f3fSDimitry Andric #include "SPIRVSubtarget.h" 23fcaf7f86SDimitry Andric #include "SPIRVTargetMachine.h" 24fcaf7f86SDimitry Andric #include "SPIRVUtils.h" 25bdd1243dSDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h" 26fcaf7f86SDimitry Andric #include "llvm/IR/IRBuilder.h" 27fcaf7f86SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 28*5f757f3fSDimitry Andric #include "llvm/IR/Intrinsics.h" 29*5f757f3fSDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h" 30fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h" 31fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 32fcaf7f86SDimitry Andric 33fcaf7f86SDimitry Andric using namespace llvm; 34fcaf7f86SDimitry Andric 35fcaf7f86SDimitry Andric namespace llvm { 36fcaf7f86SDimitry Andric void initializeSPIRVPrepareFunctionsPass(PassRegistry &); 37fcaf7f86SDimitry Andric } 38fcaf7f86SDimitry Andric 39fcaf7f86SDimitry Andric namespace { 40fcaf7f86SDimitry Andric 41fcaf7f86SDimitry Andric class SPIRVPrepareFunctions : public ModulePass { 42*5f757f3fSDimitry Andric const SPIRVTargetMachine &TM; 4306c3fb27SDimitry Andric bool substituteIntrinsicCalls(Function *F); 4406c3fb27SDimitry Andric Function *removeAggregateTypesFromSignature(Function *F); 45fcaf7f86SDimitry Andric 46fcaf7f86SDimitry Andric public: 47fcaf7f86SDimitry Andric static char ID; 48*5f757f3fSDimitry Andric SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) { 49fcaf7f86SDimitry Andric initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry()); 50fcaf7f86SDimitry Andric } 51fcaf7f86SDimitry Andric 52fcaf7f86SDimitry Andric bool runOnModule(Module &M) override; 53fcaf7f86SDimitry Andric 54fcaf7f86SDimitry Andric StringRef getPassName() const override { return "SPIRV prepare 
functions"; } 55fcaf7f86SDimitry Andric 56fcaf7f86SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 57fcaf7f86SDimitry Andric ModulePass::getAnalysisUsage(AU); 58fcaf7f86SDimitry Andric } 59fcaf7f86SDimitry Andric }; 60fcaf7f86SDimitry Andric 61fcaf7f86SDimitry Andric } // namespace 62fcaf7f86SDimitry Andric 63fcaf7f86SDimitry Andric char SPIRVPrepareFunctions::ID = 0; 64fcaf7f86SDimitry Andric 65fcaf7f86SDimitry Andric INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions", 66fcaf7f86SDimitry Andric "SPIRV prepare functions", false, false) 67fcaf7f86SDimitry Andric 6806c3fb27SDimitry Andric std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { 6906c3fb27SDimitry Andric Function *IntrinsicFunc = II->getCalledFunction(); 7006c3fb27SDimitry Andric assert(IntrinsicFunc && "Missing function"); 7106c3fb27SDimitry Andric std::string FuncName = IntrinsicFunc->getName().str(); 7206c3fb27SDimitry Andric std::replace(FuncName.begin(), FuncName.end(), '.', '_'); 7306c3fb27SDimitry Andric FuncName = "spirv." 
+ FuncName; 7406c3fb27SDimitry Andric return FuncName; 7506c3fb27SDimitry Andric } 7606c3fb27SDimitry Andric 7706c3fb27SDimitry Andric static Function *getOrCreateFunction(Module *M, Type *RetTy, 7806c3fb27SDimitry Andric ArrayRef<Type *> ArgTypes, 7906c3fb27SDimitry Andric StringRef Name) { 8006c3fb27SDimitry Andric FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false); 8106c3fb27SDimitry Andric Function *F = M->getFunction(Name); 8206c3fb27SDimitry Andric if (F && F->getFunctionType() == FT) 8306c3fb27SDimitry Andric return F; 8406c3fb27SDimitry Andric Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M); 8506c3fb27SDimitry Andric if (F) 8606c3fb27SDimitry Andric NewF->setDSOLocal(F->isDSOLocal()); 8706c3fb27SDimitry Andric NewF->setCallingConv(CallingConv::SPIR_FUNC); 8806c3fb27SDimitry Andric return NewF; 8906c3fb27SDimitry Andric } 9006c3fb27SDimitry Andric 9106c3fb27SDimitry Andric static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) { 9206c3fb27SDimitry Andric // For @llvm.memset.* intrinsic cases with constant value and length arguments 9306c3fb27SDimitry Andric // are emulated via "storing" a constant array to the destination. For other 9406c3fb27SDimitry Andric // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the 9506c3fb27SDimitry Andric // intrinsic to a loop via expandMemSetAsLoop(). 9606c3fb27SDimitry Andric if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic)) 9706c3fb27SDimitry Andric if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength())) 9806c3fb27SDimitry Andric return false; // It is handled later using OpCopyMemorySized. 
9906c3fb27SDimitry Andric 10006c3fb27SDimitry Andric Module *M = Intrinsic->getModule(); 10106c3fb27SDimitry Andric std::string FuncName = lowerLLVMIntrinsicName(Intrinsic); 10206c3fb27SDimitry Andric if (Intrinsic->isVolatile()) 10306c3fb27SDimitry Andric FuncName += ".volatile"; 10406c3fb27SDimitry Andric // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_* 10506c3fb27SDimitry Andric Function *F = M->getFunction(FuncName); 10606c3fb27SDimitry Andric if (F) { 10706c3fb27SDimitry Andric Intrinsic->setCalledFunction(F); 10806c3fb27SDimitry Andric return true; 10906c3fb27SDimitry Andric } 11006c3fb27SDimitry Andric // TODO copy arguments attributes: nocapture writeonly. 11106c3fb27SDimitry Andric FunctionCallee FC = 11206c3fb27SDimitry Andric M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType()); 11306c3fb27SDimitry Andric auto IntrinsicID = Intrinsic->getIntrinsicID(); 11406c3fb27SDimitry Andric Intrinsic->setCalledFunction(FC); 11506c3fb27SDimitry Andric 11606c3fb27SDimitry Andric F = dyn_cast<Function>(FC.getCallee()); 11706c3fb27SDimitry Andric assert(F && "Callee must be a function"); 11806c3fb27SDimitry Andric 11906c3fb27SDimitry Andric switch (IntrinsicID) { 12006c3fb27SDimitry Andric case Intrinsic::memset: { 12106c3fb27SDimitry Andric auto *MSI = static_cast<MemSetInst *>(Intrinsic); 12206c3fb27SDimitry Andric Argument *Dest = F->getArg(0); 12306c3fb27SDimitry Andric Argument *Val = F->getArg(1); 12406c3fb27SDimitry Andric Argument *Len = F->getArg(2); 12506c3fb27SDimitry Andric Argument *IsVolatile = F->getArg(3); 12606c3fb27SDimitry Andric Dest->setName("dest"); 12706c3fb27SDimitry Andric Val->setName("val"); 12806c3fb27SDimitry Andric Len->setName("len"); 12906c3fb27SDimitry Andric IsVolatile->setName("isvolatile"); 13006c3fb27SDimitry Andric BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); 13106c3fb27SDimitry Andric IRBuilder<> IRB(EntryBB); 13206c3fb27SDimitry Andric auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, 
MSI->getDestAlign(), 13306c3fb27SDimitry Andric MSI->isVolatile()); 13406c3fb27SDimitry Andric IRB.CreateRetVoid(); 13506c3fb27SDimitry Andric expandMemSetAsLoop(cast<MemSetInst>(MemSet)); 13606c3fb27SDimitry Andric MemSet->eraseFromParent(); 13706c3fb27SDimitry Andric break; 13806c3fb27SDimitry Andric } 13906c3fb27SDimitry Andric case Intrinsic::bswap: { 14006c3fb27SDimitry Andric BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); 14106c3fb27SDimitry Andric IRBuilder<> IRB(EntryBB); 14206c3fb27SDimitry Andric auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(), 14306c3fb27SDimitry Andric F->getArg(0)); 14406c3fb27SDimitry Andric IRB.CreateRet(BSwap); 14506c3fb27SDimitry Andric IntrinsicLowering IL(M->getDataLayout()); 14606c3fb27SDimitry Andric IL.LowerIntrinsicCall(BSwap); 14706c3fb27SDimitry Andric break; 14806c3fb27SDimitry Andric } 14906c3fb27SDimitry Andric default: 15006c3fb27SDimitry Andric break; 15106c3fb27SDimitry Andric } 15206c3fb27SDimitry Andric return true; 15306c3fb27SDimitry Andric } 15406c3fb27SDimitry Andric 15506c3fb27SDimitry Andric static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) { 15606c3fb27SDimitry Andric // Get a separate function - otherwise, we'd have to rework the CFG of the 15706c3fb27SDimitry Andric // current one. Then simply replace the intrinsic uses with a call to the new 15806c3fb27SDimitry Andric // function. 
15906c3fb27SDimitry Andric // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) 16006c3fb27SDimitry Andric Module *M = FSHIntrinsic->getModule(); 16106c3fb27SDimitry Andric FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); 16206c3fb27SDimitry Andric Type *FSHRetTy = FSHFuncTy->getReturnType(); 16306c3fb27SDimitry Andric const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic); 16406c3fb27SDimitry Andric Function *FSHFunc = 16506c3fb27SDimitry Andric getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName); 16606c3fb27SDimitry Andric 16706c3fb27SDimitry Andric if (!FSHFunc->empty()) { 16806c3fb27SDimitry Andric FSHIntrinsic->setCalledFunction(FSHFunc); 16906c3fb27SDimitry Andric return; 17006c3fb27SDimitry Andric } 17106c3fb27SDimitry Andric BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc); 17206c3fb27SDimitry Andric IRBuilder<> IRB(RotateBB); 17306c3fb27SDimitry Andric Type *Ty = FSHFunc->getReturnType(); 17406c3fb27SDimitry Andric // Build the actual funnel shift rotate logic. 17506c3fb27SDimitry Andric // In the comments, "int" is used interchangeably with "vector of int 17606c3fb27SDimitry Andric // elements". 17706c3fb27SDimitry Andric FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty); 17806c3fb27SDimitry Andric Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty; 17906c3fb27SDimitry Andric unsigned BitWidth = IntTy->getIntegerBitWidth(); 18006c3fb27SDimitry Andric ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth}); 18106c3fb27SDimitry Andric Value *BitWidthForInsts = 18206c3fb27SDimitry Andric VectorTy 18306c3fb27SDimitry Andric ? 
IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant) 18406c3fb27SDimitry Andric : BitWidthConstant; 18506c3fb27SDimitry Andric Value *RotateModVal = 18606c3fb27SDimitry Andric IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts); 18706c3fb27SDimitry Andric Value *FirstShift = nullptr, *SecShift = nullptr; 18806c3fb27SDimitry Andric if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 18906c3fb27SDimitry Andric // Shift the less significant number right, the "rotate" number of bits 19006c3fb27SDimitry Andric // will be 0-filled on the left as a result of this regular shift. 19106c3fb27SDimitry Andric FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal); 19206c3fb27SDimitry Andric } else { 19306c3fb27SDimitry Andric // Shift the more significant number left, the "rotate" number of bits 19406c3fb27SDimitry Andric // will be 0-filled on the right as a result of this regular shift. 19506c3fb27SDimitry Andric FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal); 19606c3fb27SDimitry Andric } 19706c3fb27SDimitry Andric // We want the "rotate" number of the more significant int's LSBs (MSBs) to 19806c3fb27SDimitry Andric // occupy the leftmost (rightmost) "0 space" left by the previous operation. 19906c3fb27SDimitry Andric // Therefore, subtract the "rotate" number from the integer bitsize... 20006c3fb27SDimitry Andric Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal); 20106c3fb27SDimitry Andric if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 20206c3fb27SDimitry Andric // ...and left-shift the more significant int by this number, zero-filling 20306c3fb27SDimitry Andric // the LSBs. 20406c3fb27SDimitry Andric SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal); 20506c3fb27SDimitry Andric } else { 20606c3fb27SDimitry Andric // ...and right-shift the less significant int by this number, zero-filling 20706c3fb27SDimitry Andric // the MSBs. 
20806c3fb27SDimitry Andric SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal); 20906c3fb27SDimitry Andric } 21006c3fb27SDimitry Andric // A simple binary addition of the shifted ints yields the final result. 21106c3fb27SDimitry Andric IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift)); 21206c3fb27SDimitry Andric 21306c3fb27SDimitry Andric FSHIntrinsic->setCalledFunction(FSHFunc); 21406c3fb27SDimitry Andric } 21506c3fb27SDimitry Andric 21606c3fb27SDimitry Andric static void buildUMulWithOverflowFunc(Function *UMulFunc) { 21706c3fb27SDimitry Andric // The function body is already created. 21806c3fb27SDimitry Andric if (!UMulFunc->empty()) 21906c3fb27SDimitry Andric return; 22006c3fb27SDimitry Andric 22106c3fb27SDimitry Andric BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(), 22206c3fb27SDimitry Andric "entry", UMulFunc); 22306c3fb27SDimitry Andric IRBuilder<> IRB(EntryBB); 22406c3fb27SDimitry Andric // Build the actual unsigned multiplication logic with the overflow 22506c3fb27SDimitry Andric // indication. Do unsigned multiplication Mul = A * B. Then check 22606c3fb27SDimitry Andric // if unsigned division Div = Mul / A is not equal to B. If so, 22706c3fb27SDimitry Andric // then overflow has happened. 22806c3fb27SDimitry Andric Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1)); 22906c3fb27SDimitry Andric Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0)); 23006c3fb27SDimitry Andric Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div); 23106c3fb27SDimitry Andric 23206c3fb27SDimitry Andric // umul.with.overflow intrinsic return a structure, where the first element 23306c3fb27SDimitry Andric // is the multiplication result, and the second is an overflow bit. 
23406c3fb27SDimitry Andric Type *StructTy = UMulFunc->getReturnType(); 23506c3fb27SDimitry Andric Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0}); 23606c3fb27SDimitry Andric Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1}); 23706c3fb27SDimitry Andric IRB.CreateRet(Res); 23806c3fb27SDimitry Andric } 23906c3fb27SDimitry Andric 240*5f757f3fSDimitry Andric static void lowerExpectAssume(IntrinsicInst *II) { 241*5f757f3fSDimitry Andric // If we cannot use the SPV_KHR_expect_assume extension, then we need to 242*5f757f3fSDimitry Andric // ignore the intrinsic and move on. It should be removed later on by LLVM. 243*5f757f3fSDimitry Andric // Otherwise we should lower the intrinsic to the corresponding SPIR-V 244*5f757f3fSDimitry Andric // instruction. 245*5f757f3fSDimitry Andric // For @llvm.assume we have OpAssumeTrueKHR. 246*5f757f3fSDimitry Andric // For @llvm.expect we have OpExpectKHR. 247*5f757f3fSDimitry Andric // 248*5f757f3fSDimitry Andric // We need to lower this into a builtin and then the builtin into a SPIR-V 249*5f757f3fSDimitry Andric // instruction. 
250*5f757f3fSDimitry Andric if (II->getIntrinsicID() == Intrinsic::assume) { 251*5f757f3fSDimitry Andric Function *F = Intrinsic::getDeclaration( 252*5f757f3fSDimitry Andric II->getModule(), Intrinsic::SPVIntrinsics::spv_assume); 253*5f757f3fSDimitry Andric II->setCalledFunction(F); 254*5f757f3fSDimitry Andric } else if (II->getIntrinsicID() == Intrinsic::expect) { 255*5f757f3fSDimitry Andric Function *F = Intrinsic::getDeclaration( 256*5f757f3fSDimitry Andric II->getModule(), Intrinsic::SPVIntrinsics::spv_expect, 257*5f757f3fSDimitry Andric {II->getOperand(0)->getType()}); 258*5f757f3fSDimitry Andric II->setCalledFunction(F); 259*5f757f3fSDimitry Andric } else { 260*5f757f3fSDimitry Andric llvm_unreachable("Unknown intrinsic"); 261*5f757f3fSDimitry Andric } 262*5f757f3fSDimitry Andric 263*5f757f3fSDimitry Andric return; 264*5f757f3fSDimitry Andric } 265*5f757f3fSDimitry Andric 26606c3fb27SDimitry Andric static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) { 26706c3fb27SDimitry Andric // Get a separate function - otherwise, we'd have to rework the CFG of the 26806c3fb27SDimitry Andric // current one. Then simply replace the intrinsic uses with a call to the new 26906c3fb27SDimitry Andric // function. 
27006c3fb27SDimitry Andric Module *M = UMulIntrinsic->getModule(); 27106c3fb27SDimitry Andric FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType(); 27206c3fb27SDimitry Andric Type *FSHLRetTy = UMulFuncTy->getReturnType(); 27306c3fb27SDimitry Andric const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic); 27406c3fb27SDimitry Andric Function *UMulFunc = 27506c3fb27SDimitry Andric getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName); 27606c3fb27SDimitry Andric buildUMulWithOverflowFunc(UMulFunc); 27706c3fb27SDimitry Andric UMulIntrinsic->setCalledFunction(UMulFunc); 27806c3fb27SDimitry Andric } 27906c3fb27SDimitry Andric 28006c3fb27SDimitry Andric // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics 28106c3fb27SDimitry Andric // or calls to proper generated functions. Returns True if F was modified. 28206c3fb27SDimitry Andric bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { 28306c3fb27SDimitry Andric bool Changed = false; 28406c3fb27SDimitry Andric for (BasicBlock &BB : *F) { 28506c3fb27SDimitry Andric for (Instruction &I : BB) { 28606c3fb27SDimitry Andric auto Call = dyn_cast<CallInst>(&I); 28706c3fb27SDimitry Andric if (!Call) 28806c3fb27SDimitry Andric continue; 28906c3fb27SDimitry Andric Function *CF = Call->getCalledFunction(); 29006c3fb27SDimitry Andric if (!CF || !CF->isIntrinsic()) 29106c3fb27SDimitry Andric continue; 29206c3fb27SDimitry Andric auto *II = cast<IntrinsicInst>(Call); 29306c3fb27SDimitry Andric if (II->getIntrinsicID() == Intrinsic::memset || 29406c3fb27SDimitry Andric II->getIntrinsicID() == Intrinsic::bswap) 29506c3fb27SDimitry Andric Changed |= lowerIntrinsicToFunction(II); 29606c3fb27SDimitry Andric else if (II->getIntrinsicID() == Intrinsic::fshl || 29706c3fb27SDimitry Andric II->getIntrinsicID() == Intrinsic::fshr) { 29806c3fb27SDimitry Andric lowerFunnelShifts(II); 29906c3fb27SDimitry Andric Changed = true; 30006c3fb27SDimitry Andric } else if (II->getIntrinsicID() 
== Intrinsic::umul_with_overflow) { 30106c3fb27SDimitry Andric lowerUMulWithOverflow(II); 30206c3fb27SDimitry Andric Changed = true; 303*5f757f3fSDimitry Andric } else if (II->getIntrinsicID() == Intrinsic::assume || 304*5f757f3fSDimitry Andric II->getIntrinsicID() == Intrinsic::expect) { 305*5f757f3fSDimitry Andric const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F); 306*5f757f3fSDimitry Andric if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) 307*5f757f3fSDimitry Andric lowerExpectAssume(II); 308*5f757f3fSDimitry Andric Changed = true; 30906c3fb27SDimitry Andric } 31006c3fb27SDimitry Andric } 31106c3fb27SDimitry Andric } 31206c3fb27SDimitry Andric return Changed; 31306c3fb27SDimitry Andric } 31406c3fb27SDimitry Andric 31506c3fb27SDimitry Andric // Returns F if aggregate argument/return types are not present or cloned F 31606c3fb27SDimitry Andric // function with the types replaced by i32 types. The change in types is 31706c3fb27SDimitry Andric // noted in 'spv.cloned_funcs' metadata for later restoration. 31806c3fb27SDimitry Andric Function * 31906c3fb27SDimitry Andric SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { 320fcaf7f86SDimitry Andric IRBuilder<> B(F->getContext()); 321fcaf7f86SDimitry Andric 322fcaf7f86SDimitry Andric bool IsRetAggr = F->getReturnType()->isAggregateType(); 323fcaf7f86SDimitry Andric bool HasAggrArg = 324fcaf7f86SDimitry Andric std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) { 325fcaf7f86SDimitry Andric return Arg.getType()->isAggregateType(); 326fcaf7f86SDimitry Andric }); 327fcaf7f86SDimitry Andric bool DoClone = IsRetAggr || HasAggrArg; 328fcaf7f86SDimitry Andric if (!DoClone) 329fcaf7f86SDimitry Andric return F; 330fcaf7f86SDimitry Andric SmallVector<std::pair<int, Type *>, 4> ChangedTypes; 331fcaf7f86SDimitry Andric Type *RetType = IsRetAggr ? 
B.getInt32Ty() : F->getReturnType(); 332fcaf7f86SDimitry Andric if (IsRetAggr) 333fcaf7f86SDimitry Andric ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType())); 334fcaf7f86SDimitry Andric SmallVector<Type *, 4> ArgTypes; 335fcaf7f86SDimitry Andric for (const auto &Arg : F->args()) { 336fcaf7f86SDimitry Andric if (Arg.getType()->isAggregateType()) { 337fcaf7f86SDimitry Andric ArgTypes.push_back(B.getInt32Ty()); 338fcaf7f86SDimitry Andric ChangedTypes.push_back( 339fcaf7f86SDimitry Andric std::pair<int, Type *>(Arg.getArgNo(), Arg.getType())); 340fcaf7f86SDimitry Andric } else 341fcaf7f86SDimitry Andric ArgTypes.push_back(Arg.getType()); 342fcaf7f86SDimitry Andric } 343fcaf7f86SDimitry Andric FunctionType *NewFTy = 344fcaf7f86SDimitry Andric FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg()); 345fcaf7f86SDimitry Andric Function *NewF = 346fcaf7f86SDimitry Andric Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent()); 347fcaf7f86SDimitry Andric 348fcaf7f86SDimitry Andric ValueToValueMapTy VMap; 349fcaf7f86SDimitry Andric auto NewFArgIt = NewF->arg_begin(); 350fcaf7f86SDimitry Andric for (auto &Arg : F->args()) { 351fcaf7f86SDimitry Andric StringRef ArgName = Arg.getName(); 352fcaf7f86SDimitry Andric NewFArgIt->setName(ArgName); 353fcaf7f86SDimitry Andric VMap[&Arg] = &(*NewFArgIt++); 354fcaf7f86SDimitry Andric } 355fcaf7f86SDimitry Andric SmallVector<ReturnInst *, 8> Returns; 356fcaf7f86SDimitry Andric 357fcaf7f86SDimitry Andric CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, 358fcaf7f86SDimitry Andric Returns); 359fcaf7f86SDimitry Andric NewF->takeName(F); 360fcaf7f86SDimitry Andric 361fcaf7f86SDimitry Andric NamedMDNode *FuncMD = 362fcaf7f86SDimitry Andric F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"); 363fcaf7f86SDimitry Andric SmallVector<Metadata *, 2> MDArgs; 364fcaf7f86SDimitry Andric MDArgs.push_back(MDString::get(B.getContext(), NewF->getName())); 
365fcaf7f86SDimitry Andric for (auto &ChangedTyP : ChangedTypes) 366fcaf7f86SDimitry Andric MDArgs.push_back(MDNode::get( 367fcaf7f86SDimitry Andric B.getContext(), 368fcaf7f86SDimitry Andric {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)), 369fcaf7f86SDimitry Andric ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))})); 370fcaf7f86SDimitry Andric MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs); 371fcaf7f86SDimitry Andric FuncMD->addOperand(ThisFuncMD); 372fcaf7f86SDimitry Andric 373fcaf7f86SDimitry Andric for (auto *U : make_early_inc_range(F->users())) { 374fcaf7f86SDimitry Andric if (auto *CI = dyn_cast<CallInst>(U)) 375fcaf7f86SDimitry Andric CI->mutateFunctionType(NewF->getFunctionType()); 376fcaf7f86SDimitry Andric U->replaceUsesOfWith(F, NewF); 377fcaf7f86SDimitry Andric } 378fcaf7f86SDimitry Andric return NewF; 379fcaf7f86SDimitry Andric } 380fcaf7f86SDimitry Andric 381fcaf7f86SDimitry Andric bool SPIRVPrepareFunctions::runOnModule(Module &M) { 38206c3fb27SDimitry Andric bool Changed = false; 383fcaf7f86SDimitry Andric for (Function &F : M) 38406c3fb27SDimitry Andric Changed |= substituteIntrinsicCalls(&F); 385fcaf7f86SDimitry Andric 386fcaf7f86SDimitry Andric std::vector<Function *> FuncsWorklist; 387fcaf7f86SDimitry Andric for (auto &F : M) 388fcaf7f86SDimitry Andric FuncsWorklist.push_back(&F); 389fcaf7f86SDimitry Andric 39006c3fb27SDimitry Andric for (auto *F : FuncsWorklist) { 39106c3fb27SDimitry Andric Function *NewF = removeAggregateTypesFromSignature(F); 392fcaf7f86SDimitry Andric 39306c3fb27SDimitry Andric if (NewF != F) { 39406c3fb27SDimitry Andric F->eraseFromParent(); 39506c3fb27SDimitry Andric Changed = true; 396fcaf7f86SDimitry Andric } 397fcaf7f86SDimitry Andric } 398fcaf7f86SDimitry Andric return Changed; 399fcaf7f86SDimitry Andric } 400fcaf7f86SDimitry Andric 401*5f757f3fSDimitry Andric ModulePass * 402*5f757f3fSDimitry Andric llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) { 
403*5f757f3fSDimitry Andric return new SPIRVPrepareFunctions(TM); 404fcaf7f86SDimitry Andric } 405