1*fcaf7f86SDimitry Andric //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===// 2*fcaf7f86SDimitry Andric // 3*fcaf7f86SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*fcaf7f86SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*fcaf7f86SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*fcaf7f86SDimitry Andric // 7*fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===// 8*fcaf7f86SDimitry Andric // 9*fcaf7f86SDimitry Andric // This pass modifies function signatures containing aggregate arguments 10*fcaf7f86SDimitry Andric // and/or return value. Also it substitutes some llvm intrinsic calls by 11*fcaf7f86SDimitry Andric // function calls, generating these functions as the translator does. 12*fcaf7f86SDimitry Andric // 13*fcaf7f86SDimitry Andric // NOTE: this pass is a module-level one due to the necessity to modify 14*fcaf7f86SDimitry Andric // GVs/functions. 15*fcaf7f86SDimitry Andric // 16*fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===// 17*fcaf7f86SDimitry Andric 18*fcaf7f86SDimitry Andric #include "SPIRV.h" 19*fcaf7f86SDimitry Andric #include "SPIRVTargetMachine.h" 20*fcaf7f86SDimitry Andric #include "SPIRVUtils.h" 21*fcaf7f86SDimitry Andric #include "llvm/IR/IRBuilder.h" 22*fcaf7f86SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 23*fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h" 24*fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" 25*fcaf7f86SDimitry Andric 26*fcaf7f86SDimitry Andric using namespace llvm; 27*fcaf7f86SDimitry Andric 28*fcaf7f86SDimitry Andric namespace llvm { 29*fcaf7f86SDimitry Andric void initializeSPIRVPrepareFunctionsPass(PassRegistry &); 30*fcaf7f86SDimitry Andric } 31*fcaf7f86SDimitry Andric 32*fcaf7f86SDimitry Andric namespace { 33*fcaf7f86SDimitry Andric 34*fcaf7f86SDimitry Andric class SPIRVPrepareFunctions : public ModulePass { 35*fcaf7f86SDimitry Andric Function *processFunctionSignature(Function *F); 36*fcaf7f86SDimitry Andric 37*fcaf7f86SDimitry Andric public: 38*fcaf7f86SDimitry Andric static char ID; 39*fcaf7f86SDimitry Andric SPIRVPrepareFunctions() : ModulePass(ID) { 40*fcaf7f86SDimitry Andric initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry()); 41*fcaf7f86SDimitry Andric } 42*fcaf7f86SDimitry Andric 43*fcaf7f86SDimitry Andric bool runOnModule(Module &M) override; 44*fcaf7f86SDimitry Andric 45*fcaf7f86SDimitry Andric StringRef getPassName() const override { return "SPIRV prepare functions"; } 46*fcaf7f86SDimitry Andric 47*fcaf7f86SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 48*fcaf7f86SDimitry Andric ModulePass::getAnalysisUsage(AU); 49*fcaf7f86SDimitry Andric } 50*fcaf7f86SDimitry Andric }; 51*fcaf7f86SDimitry Andric 52*fcaf7f86SDimitry Andric } // namespace 53*fcaf7f86SDimitry Andric 54*fcaf7f86SDimitry Andric char SPIRVPrepareFunctions::ID = 0; 55*fcaf7f86SDimitry Andric 56*fcaf7f86SDimitry Andric INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions", 57*fcaf7f86SDimitry Andric "SPIRV prepare functions", false, false) 58*fcaf7f86SDimitry Andric 59*fcaf7f86SDimitry Andric Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) { 60*fcaf7f86SDimitry Andric IRBuilder<> B(F->getContext()); 61*fcaf7f86SDimitry Andric 62*fcaf7f86SDimitry Andric bool IsRetAggr = F->getReturnType()->isAggregateType(); 63*fcaf7f86SDimitry Andric bool HasAggrArg = 64*fcaf7f86SDimitry Andric std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) { 65*fcaf7f86SDimitry Andric return Arg.getType()->isAggregateType(); 66*fcaf7f86SDimitry Andric }); 67*fcaf7f86SDimitry Andric bool DoClone = IsRetAggr || HasAggrArg; 68*fcaf7f86SDimitry Andric if (!DoClone) 69*fcaf7f86SDimitry Andric return F; 70*fcaf7f86SDimitry Andric SmallVector<std::pair<int, Type *>, 4> ChangedTypes; 71*fcaf7f86SDimitry Andric Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType(); 72*fcaf7f86SDimitry Andric if (IsRetAggr) 73*fcaf7f86SDimitry Andric ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType())); 74*fcaf7f86SDimitry Andric SmallVector<Type *, 4> ArgTypes; 75*fcaf7f86SDimitry Andric for (const auto &Arg : F->args()) { 76*fcaf7f86SDimitry Andric if (Arg.getType()->isAggregateType()) { 77*fcaf7f86SDimitry Andric ArgTypes.push_back(B.getInt32Ty()); 78*fcaf7f86SDimitry Andric ChangedTypes.push_back( 79*fcaf7f86SDimitry Andric std::pair<int, Type *>(Arg.getArgNo(), Arg.getType())); 80*fcaf7f86SDimitry Andric } else 81*fcaf7f86SDimitry Andric ArgTypes.push_back(Arg.getType()); 82*fcaf7f86SDimitry Andric } 83*fcaf7f86SDimitry Andric FunctionType *NewFTy = 84*fcaf7f86SDimitry Andric FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg()); 85*fcaf7f86SDimitry Andric Function *NewF = 86*fcaf7f86SDimitry Andric Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent()); 87*fcaf7f86SDimitry Andric 88*fcaf7f86SDimitry Andric ValueToValueMapTy VMap; 89*fcaf7f86SDimitry Andric auto NewFArgIt = NewF->arg_begin(); 90*fcaf7f86SDimitry Andric for (auto &Arg : F->args()) { 91*fcaf7f86SDimitry Andric StringRef ArgName = Arg.getName(); 92*fcaf7f86SDimitry Andric NewFArgIt->setName(ArgName); 93*fcaf7f86SDimitry Andric VMap[&Arg] = &(*NewFArgIt++); 94*fcaf7f86SDimitry Andric } 95*fcaf7f86SDimitry Andric SmallVector<ReturnInst *, 8> Returns; 96*fcaf7f86SDimitry Andric 97*fcaf7f86SDimitry Andric CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, 98*fcaf7f86SDimitry Andric Returns); 99*fcaf7f86SDimitry Andric NewF->takeName(F); 100*fcaf7f86SDimitry Andric 101*fcaf7f86SDimitry Andric NamedMDNode *FuncMD = 102*fcaf7f86SDimitry Andric F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs"); 103*fcaf7f86SDimitry Andric SmallVector<Metadata *, 2> MDArgs; 104*fcaf7f86SDimitry Andric MDArgs.push_back(MDString::get(B.getContext(), NewF->getName())); 105*fcaf7f86SDimitry Andric for (auto &ChangedTyP : ChangedTypes) 106*fcaf7f86SDimitry Andric MDArgs.push_back(MDNode::get( 107*fcaf7f86SDimitry Andric B.getContext(), 108*fcaf7f86SDimitry Andric {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)), 109*fcaf7f86SDimitry Andric ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))})); 110*fcaf7f86SDimitry Andric MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs); 111*fcaf7f86SDimitry Andric FuncMD->addOperand(ThisFuncMD); 112*fcaf7f86SDimitry Andric 113*fcaf7f86SDimitry Andric for (auto *U : make_early_inc_range(F->users())) { 114*fcaf7f86SDimitry Andric if (auto *CI = dyn_cast<CallInst>(U)) 115*fcaf7f86SDimitry Andric CI->mutateFunctionType(NewF->getFunctionType()); 116*fcaf7f86SDimitry Andric U->replaceUsesOfWith(F, NewF); 117*fcaf7f86SDimitry Andric } 118*fcaf7f86SDimitry Andric return NewF; 119*fcaf7f86SDimitry Andric } 120*fcaf7f86SDimitry Andric 121*fcaf7f86SDimitry Andric std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { 122*fcaf7f86SDimitry Andric Function *IntrinsicFunc = II->getCalledFunction(); 123*fcaf7f86SDimitry Andric assert(IntrinsicFunc && "Missing function"); 124*fcaf7f86SDimitry Andric std::string FuncName = IntrinsicFunc->getName().str(); 125*fcaf7f86SDimitry Andric std::replace(FuncName.begin(), FuncName.end(), '.', '_'); 126*fcaf7f86SDimitry Andric FuncName = "spirv." + FuncName; 127*fcaf7f86SDimitry Andric return FuncName; 128*fcaf7f86SDimitry Andric } 129*fcaf7f86SDimitry Andric 130*fcaf7f86SDimitry Andric static Function *getOrCreateFunction(Module *M, Type *RetTy, 131*fcaf7f86SDimitry Andric ArrayRef<Type *> ArgTypes, 132*fcaf7f86SDimitry Andric StringRef Name) { 133*fcaf7f86SDimitry Andric FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false); 134*fcaf7f86SDimitry Andric Function *F = M->getFunction(Name); 135*fcaf7f86SDimitry Andric if (F && F->getFunctionType() == FT) 136*fcaf7f86SDimitry Andric return F; 137*fcaf7f86SDimitry Andric Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M); 138*fcaf7f86SDimitry Andric if (F) 139*fcaf7f86SDimitry Andric NewF->setDSOLocal(F->isDSOLocal()); 140*fcaf7f86SDimitry Andric NewF->setCallingConv(CallingConv::SPIR_FUNC); 141*fcaf7f86SDimitry Andric return NewF; 142*fcaf7f86SDimitry Andric } 143*fcaf7f86SDimitry Andric 144*fcaf7f86SDimitry Andric static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) { 145*fcaf7f86SDimitry Andric // Get a separate function - otherwise, we'd have to rework the CFG of the 146*fcaf7f86SDimitry Andric // current one. Then simply replace the intrinsic uses with a call to the new 147*fcaf7f86SDimitry Andric // function. 148*fcaf7f86SDimitry Andric // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) 149*fcaf7f86SDimitry Andric FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); 150*fcaf7f86SDimitry Andric Type *FSHRetTy = FSHFuncTy->getReturnType(); 151*fcaf7f86SDimitry Andric const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic); 152*fcaf7f86SDimitry Andric Function *FSHFunc = 153*fcaf7f86SDimitry Andric getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName); 154*fcaf7f86SDimitry Andric 155*fcaf7f86SDimitry Andric if (!FSHFunc->empty()) { 156*fcaf7f86SDimitry Andric FSHIntrinsic->setCalledFunction(FSHFunc); 157*fcaf7f86SDimitry Andric return; 158*fcaf7f86SDimitry Andric } 159*fcaf7f86SDimitry Andric BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc); 160*fcaf7f86SDimitry Andric IRBuilder<> IRB(RotateBB); 161*fcaf7f86SDimitry Andric Type *Ty = FSHFunc->getReturnType(); 162*fcaf7f86SDimitry Andric // Build the actual funnel shift rotate logic. 163*fcaf7f86SDimitry Andric // In the comments, "int" is used interchangeably with "vector of int 164*fcaf7f86SDimitry Andric // elements". 165*fcaf7f86SDimitry Andric FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty); 166*fcaf7f86SDimitry Andric Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty; 167*fcaf7f86SDimitry Andric unsigned BitWidth = IntTy->getIntegerBitWidth(); 168*fcaf7f86SDimitry Andric ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth}); 169*fcaf7f86SDimitry Andric Value *BitWidthForInsts = 170*fcaf7f86SDimitry Andric VectorTy 171*fcaf7f86SDimitry Andric ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant) 172*fcaf7f86SDimitry Andric : BitWidthConstant; 173*fcaf7f86SDimitry Andric Value *RotateModVal = 174*fcaf7f86SDimitry Andric IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts); 175*fcaf7f86SDimitry Andric Value *FirstShift = nullptr, *SecShift = nullptr; 176*fcaf7f86SDimitry Andric if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 177*fcaf7f86SDimitry Andric // Shift the less significant number right, the "rotate" number of bits 178*fcaf7f86SDimitry Andric // will be 0-filled on the left as a result of this regular shift. 179*fcaf7f86SDimitry Andric FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal); 180*fcaf7f86SDimitry Andric } else { 181*fcaf7f86SDimitry Andric // Shift the more significant number left, the "rotate" number of bits 182*fcaf7f86SDimitry Andric // will be 0-filled on the right as a result of this regular shift. 183*fcaf7f86SDimitry Andric FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal); 184*fcaf7f86SDimitry Andric } 185*fcaf7f86SDimitry Andric // We want the "rotate" number of the more significant int's LSBs (MSBs) to 186*fcaf7f86SDimitry Andric // occupy the leftmost (rightmost) "0 space" left by the previous operation. 187*fcaf7f86SDimitry Andric // Therefore, subtract the "rotate" number from the integer bitsize... 188*fcaf7f86SDimitry Andric Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal); 189*fcaf7f86SDimitry Andric if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { 190*fcaf7f86SDimitry Andric // ...and left-shift the more significant int by this number, zero-filling 191*fcaf7f86SDimitry Andric // the LSBs. 192*fcaf7f86SDimitry Andric SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal); 193*fcaf7f86SDimitry Andric } else { 194*fcaf7f86SDimitry Andric // ...and right-shift the less significant int by this number, zero-filling 195*fcaf7f86SDimitry Andric // the MSBs. 196*fcaf7f86SDimitry Andric SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal); 197*fcaf7f86SDimitry Andric } 198*fcaf7f86SDimitry Andric // A simple binary addition of the shifted ints yields the final result. 199*fcaf7f86SDimitry Andric IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift)); 200*fcaf7f86SDimitry Andric 201*fcaf7f86SDimitry Andric FSHIntrinsic->setCalledFunction(FSHFunc); 202*fcaf7f86SDimitry Andric } 203*fcaf7f86SDimitry Andric 204*fcaf7f86SDimitry Andric static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) { 205*fcaf7f86SDimitry Andric // The function body is already created. 206*fcaf7f86SDimitry Andric if (!UMulFunc->empty()) 207*fcaf7f86SDimitry Andric return; 208*fcaf7f86SDimitry Andric 209*fcaf7f86SDimitry Andric BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc); 210*fcaf7f86SDimitry Andric IRBuilder<> IRB(EntryBB); 211*fcaf7f86SDimitry Andric // Build the actual unsigned multiplication logic with the overflow 212*fcaf7f86SDimitry Andric // indication. Do unsigned multiplication Mul = A * B. Then check 213*fcaf7f86SDimitry Andric // if unsigned division Div = Mul / A is not equal to B. If so, 214*fcaf7f86SDimitry Andric // then overflow has happened. 215*fcaf7f86SDimitry Andric Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1)); 216*fcaf7f86SDimitry Andric Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0)); 217*fcaf7f86SDimitry Andric Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div); 218*fcaf7f86SDimitry Andric 219*fcaf7f86SDimitry Andric // umul.with.overflow intrinsic return a structure, where the first element 220*fcaf7f86SDimitry Andric // is the multiplication result, and the second is an overflow bit. 221*fcaf7f86SDimitry Andric Type *StructTy = UMulFunc->getReturnType(); 222*fcaf7f86SDimitry Andric Value *Agg = IRB.CreateInsertValue(UndefValue::get(StructTy), Mul, {0}); 223*fcaf7f86SDimitry Andric Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1}); 224*fcaf7f86SDimitry Andric IRB.CreateRet(Res); 225*fcaf7f86SDimitry Andric } 226*fcaf7f86SDimitry Andric 227*fcaf7f86SDimitry Andric static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) { 228*fcaf7f86SDimitry Andric // Get a separate function - otherwise, we'd have to rework the CFG of the 229*fcaf7f86SDimitry Andric // current one. Then simply replace the intrinsic uses with a call to the new 230*fcaf7f86SDimitry Andric // function. 231*fcaf7f86SDimitry Andric FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType(); 232*fcaf7f86SDimitry Andric Type *FSHLRetTy = UMulFuncTy->getReturnType(); 233*fcaf7f86SDimitry Andric const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic); 234*fcaf7f86SDimitry Andric Function *UMulFunc = 235*fcaf7f86SDimitry Andric getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName); 236*fcaf7f86SDimitry Andric buildUMulWithOverflowFunc(M, UMulFunc); 237*fcaf7f86SDimitry Andric UMulIntrinsic->setCalledFunction(UMulFunc); 238*fcaf7f86SDimitry Andric } 239*fcaf7f86SDimitry Andric 240*fcaf7f86SDimitry Andric static void substituteIntrinsicCalls(Module *M, Function *F) { 241*fcaf7f86SDimitry Andric for (BasicBlock &BB : *F) { 242*fcaf7f86SDimitry Andric for (Instruction &I : BB) { 243*fcaf7f86SDimitry Andric auto Call = dyn_cast<CallInst>(&I); 244*fcaf7f86SDimitry Andric if (!Call) 245*fcaf7f86SDimitry Andric continue; 246*fcaf7f86SDimitry Andric Call->setTailCall(false); 247*fcaf7f86SDimitry Andric Function *CF = Call->getCalledFunction(); 248*fcaf7f86SDimitry Andric if (!CF || !CF->isIntrinsic()) 249*fcaf7f86SDimitry Andric continue; 250*fcaf7f86SDimitry Andric auto *II = cast<IntrinsicInst>(Call); 251*fcaf7f86SDimitry Andric if (II->getIntrinsicID() == Intrinsic::fshl || 252*fcaf7f86SDimitry Andric II->getIntrinsicID() == Intrinsic::fshr) 253*fcaf7f86SDimitry Andric lowerFunnelShifts(M, II); 254*fcaf7f86SDimitry Andric else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) 255*fcaf7f86SDimitry Andric lowerUMulWithOverflow(M, II); 256*fcaf7f86SDimitry Andric } 257*fcaf7f86SDimitry Andric } 258*fcaf7f86SDimitry Andric } 259*fcaf7f86SDimitry Andric 260*fcaf7f86SDimitry Andric bool SPIRVPrepareFunctions::runOnModule(Module &M) { 261*fcaf7f86SDimitry Andric for (Function &F : M) 262*fcaf7f86SDimitry Andric substituteIntrinsicCalls(&M, &F); 263*fcaf7f86SDimitry Andric 264*fcaf7f86SDimitry Andric std::vector<Function *> FuncsWorklist; 265*fcaf7f86SDimitry Andric bool Changed = false; 266*fcaf7f86SDimitry Andric for (auto &F : M) 267*fcaf7f86SDimitry Andric FuncsWorklist.push_back(&F); 268*fcaf7f86SDimitry Andric 269*fcaf7f86SDimitry Andric for (auto *Func : FuncsWorklist) { 270*fcaf7f86SDimitry Andric Function *F = processFunctionSignature(Func); 271*fcaf7f86SDimitry Andric 272*fcaf7f86SDimitry Andric bool CreatedNewF = F != Func; 273*fcaf7f86SDimitry Andric 274*fcaf7f86SDimitry Andric if (Func->isDeclaration()) { 275*fcaf7f86SDimitry Andric Changed |= CreatedNewF; 276*fcaf7f86SDimitry Andric continue; 277*fcaf7f86SDimitry Andric } 278*fcaf7f86SDimitry Andric 279*fcaf7f86SDimitry Andric if (CreatedNewF) 280*fcaf7f86SDimitry Andric Func->eraseFromParent(); 281*fcaf7f86SDimitry Andric } 282*fcaf7f86SDimitry Andric 283*fcaf7f86SDimitry Andric return Changed; 284*fcaf7f86SDimitry Andric } 285*fcaf7f86SDimitry Andric 286*fcaf7f86SDimitry Andric ModulePass *llvm::createSPIRVPrepareFunctionsPass() { 287*fcaf7f86SDimitry Andric return new SPIRVPrepareFunctions(); 288*fcaf7f86SDimitry Andric } 289