181ad6265SDimitry Andric //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===// 281ad6265SDimitry Andric // 381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 681ad6265SDimitry Andric // 781ad6265SDimitry Andric //===----------------------------------------------------------------------===// 881ad6265SDimitry Andric // 981ad6265SDimitry Andric // This file implements the lowering of LLVM calls to machine code calls for 1081ad6265SDimitry Andric // GlobalISel. 1181ad6265SDimitry Andric // 1281ad6265SDimitry Andric //===----------------------------------------------------------------------===// 1381ad6265SDimitry Andric 1481ad6265SDimitry Andric #include "SPIRVCallLowering.h" 1581ad6265SDimitry Andric #include "MCTargetDesc/SPIRVBaseInfo.h" 1681ad6265SDimitry Andric #include "SPIRV.h" 17*bdd1243dSDimitry Andric #include "SPIRVBuiltins.h" 1881ad6265SDimitry Andric #include "SPIRVGlobalRegistry.h" 1981ad6265SDimitry Andric #include "SPIRVISelLowering.h" 2081ad6265SDimitry Andric #include "SPIRVRegisterInfo.h" 2181ad6265SDimitry Andric #include "SPIRVSubtarget.h" 2281ad6265SDimitry Andric #include "SPIRVUtils.h" 2381ad6265SDimitry Andric #include "llvm/CodeGen/FunctionLoweringInfo.h" 24*bdd1243dSDimitry Andric #include "llvm/Support/ModRef.h" 2581ad6265SDimitry Andric 2681ad6265SDimitry Andric using namespace llvm; 2781ad6265SDimitry Andric 2881ad6265SDimitry Andric SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, 2981ad6265SDimitry Andric SPIRVGlobalRegistry *GR) 30fcaf7f86SDimitry Andric : CallLowering(&TLI), GR(GR) {} 3181ad6265SDimitry Andric 3281ad6265SDimitry Andric bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, 3381ad6265SDimitry Andric const Value *Val, ArrayRef<Register> VRegs, 3481ad6265SDimitry Andric FunctionLoweringInfo &FLI, 3581ad6265SDimitry Andric Register SwiftErrorVReg) const { 3681ad6265SDimitry Andric // Currently all return types should use a single register. 3781ad6265SDimitry Andric // TODO: handle the case of multiple registers. 3881ad6265SDimitry Andric if (VRegs.size() > 1) 3981ad6265SDimitry Andric return false; 40fcaf7f86SDimitry Andric if (Val) { 41fcaf7f86SDimitry Andric const auto &STI = MIRBuilder.getMF().getSubtarget(); 4281ad6265SDimitry Andric return MIRBuilder.buildInstr(SPIRV::OpReturnValue) 4381ad6265SDimitry Andric .addUse(VRegs[0]) 44fcaf7f86SDimitry Andric .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), 45fcaf7f86SDimitry Andric *STI.getRegBankInfo()); 46fcaf7f86SDimitry Andric } 4781ad6265SDimitry Andric MIRBuilder.buildInstr(SPIRV::OpReturn); 4881ad6265SDimitry Andric return true; 4981ad6265SDimitry Andric } 5081ad6265SDimitry Andric 5181ad6265SDimitry Andric // Based on the LLVM function attributes, get a SPIR-V FunctionControl. 5281ad6265SDimitry Andric static uint32_t getFunctionControl(const Function &F) { 53*bdd1243dSDimitry Andric MemoryEffects MemEffects = F.getMemoryEffects(); 54*bdd1243dSDimitry Andric 5581ad6265SDimitry Andric uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None); 56*bdd1243dSDimitry Andric 57*bdd1243dSDimitry Andric if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) 5881ad6265SDimitry Andric FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); 59*bdd1243dSDimitry Andric else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) 60*bdd1243dSDimitry Andric FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline); 61*bdd1243dSDimitry Andric 62*bdd1243dSDimitry Andric if (MemEffects.doesNotAccessMemory()) 63*bdd1243dSDimitry Andric FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure); 64*bdd1243dSDimitry Andric else if (MemEffects.onlyReadsMemory()) 65*bdd1243dSDimitry Andric FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const); 66*bdd1243dSDimitry Andric 6781ad6265SDimitry Andric return FuncControl; 6881ad6265SDimitry Andric } 6981ad6265SDimitry Andric 70fcaf7f86SDimitry Andric static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) { 71fcaf7f86SDimitry Andric if (MD->getNumOperands() > NumOp) { 72fcaf7f86SDimitry Andric auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp)); 73fcaf7f86SDimitry Andric if (CMeta) 74fcaf7f86SDimitry Andric return dyn_cast<ConstantInt>(CMeta->getValue()); 75fcaf7f86SDimitry Andric } 76fcaf7f86SDimitry Andric return nullptr; 77fcaf7f86SDimitry Andric } 78fcaf7f86SDimitry Andric 79fcaf7f86SDimitry Andric // This code restores function args/retvalue types for composite cases 80fcaf7f86SDimitry Andric // because the final types should still be aggregate whereas they're i32 81fcaf7f86SDimitry Andric // during the translation to cope with aggregate flattening etc. 82fcaf7f86SDimitry Andric static FunctionType *getOriginalFunctionType(const Function &F) { 83fcaf7f86SDimitry Andric auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs"); 84fcaf7f86SDimitry Andric if (NamedMD == nullptr) 85fcaf7f86SDimitry Andric return F.getFunctionType(); 86fcaf7f86SDimitry Andric 87fcaf7f86SDimitry Andric Type *RetTy = F.getFunctionType()->getReturnType(); 88fcaf7f86SDimitry Andric SmallVector<Type *, 4> ArgTypes; 89fcaf7f86SDimitry Andric for (auto &Arg : F.args()) 90fcaf7f86SDimitry Andric ArgTypes.push_back(Arg.getType()); 91fcaf7f86SDimitry Andric 92fcaf7f86SDimitry Andric auto ThisFuncMDIt = 93fcaf7f86SDimitry Andric std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) { 94fcaf7f86SDimitry Andric return isa<MDString>(N->getOperand(0)) && 95fcaf7f86SDimitry Andric cast<MDString>(N->getOperand(0))->getString() == F.getName(); 96fcaf7f86SDimitry Andric }); 97fcaf7f86SDimitry Andric // TODO: probably one function can have numerous type mutations, 98fcaf7f86SDimitry Andric // so we should support this. 99fcaf7f86SDimitry Andric if (ThisFuncMDIt != NamedMD->op_end()) { 100fcaf7f86SDimitry Andric auto *ThisFuncMD = *ThisFuncMDIt; 101fcaf7f86SDimitry Andric MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1)); 102fcaf7f86SDimitry Andric assert(MD && "MDNode operand is expected"); 103fcaf7f86SDimitry Andric ConstantInt *Const = getConstInt(MD, 0); 104fcaf7f86SDimitry Andric if (Const) { 105fcaf7f86SDimitry Andric auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1)); 106fcaf7f86SDimitry Andric assert(CMeta && "ConstantAsMetadata operand is expected"); 107fcaf7f86SDimitry Andric assert(Const->getSExtValue() >= -1); 108fcaf7f86SDimitry Andric // Currently -1 indicates return value, greater values mean 109fcaf7f86SDimitry Andric // argument numbers. 110fcaf7f86SDimitry Andric if (Const->getSExtValue() == -1) 111fcaf7f86SDimitry Andric RetTy = CMeta->getType(); 112fcaf7f86SDimitry Andric else 113fcaf7f86SDimitry Andric ArgTypes[Const->getSExtValue()] = CMeta->getType(); 114fcaf7f86SDimitry Andric } 115fcaf7f86SDimitry Andric } 116fcaf7f86SDimitry Andric 117fcaf7f86SDimitry Andric return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); 118fcaf7f86SDimitry Andric } 119fcaf7f86SDimitry Andric 120*bdd1243dSDimitry Andric static MDString *getKernelArgAttribute(const Function &KernelFunction, 121*bdd1243dSDimitry Andric unsigned ArgIdx, 122*bdd1243dSDimitry Andric const StringRef AttributeName) { 123*bdd1243dSDimitry Andric assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL && 124*bdd1243dSDimitry Andric "Kernel attributes are attached/belong only to kernel functions"); 125*bdd1243dSDimitry Andric 126*bdd1243dSDimitry Andric // Lookup the argument attribute in metadata attached to the kernel function. 127*bdd1243dSDimitry Andric MDNode *Node = KernelFunction.getMetadata(AttributeName); 128*bdd1243dSDimitry Andric if (Node && ArgIdx < Node->getNumOperands()) 129*bdd1243dSDimitry Andric return cast<MDString>(Node->getOperand(ArgIdx)); 130*bdd1243dSDimitry Andric 131*bdd1243dSDimitry Andric // Sometimes metadata containing kernel attributes is not attached to the 132*bdd1243dSDimitry Andric // function, but can be found in the named module-level metadata instead. 133*bdd1243dSDimitry Andric // For example: 134*bdd1243dSDimitry Andric // !opencl.kernels = !{!0} 135*bdd1243dSDimitry Andric // !0 = !{void ()* @someKernelFunction, !1, ...} 136*bdd1243dSDimitry Andric // !1 = !{!"kernel_arg_addr_space", ...} 137*bdd1243dSDimitry Andric // In this case the actual index of searched argument attribute is ArgIdx + 1, 138*bdd1243dSDimitry Andric // since the first metadata node operand is occupied by attribute name 139*bdd1243dSDimitry Andric // ("kernel_arg_addr_space" in the example above). 140*bdd1243dSDimitry Andric unsigned MDArgIdx = ArgIdx + 1; 141*bdd1243dSDimitry Andric NamedMDNode *OpenCLKernelsMD = 142*bdd1243dSDimitry Andric KernelFunction.getParent()->getNamedMetadata("opencl.kernels"); 143*bdd1243dSDimitry Andric if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0) 144*bdd1243dSDimitry Andric return nullptr; 145*bdd1243dSDimitry Andric 146*bdd1243dSDimitry Andric // KernelToMDNodeList contains kernel function declarations followed by 147*bdd1243dSDimitry Andric // corresponding MDNodes for each attribute. Search only MDNodes "belonging" 148*bdd1243dSDimitry Andric // to the currently lowered kernel function. 149*bdd1243dSDimitry Andric MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0); 150*bdd1243dSDimitry Andric bool FoundLoweredKernelFunction = false; 151*bdd1243dSDimitry Andric for (const MDOperand &Operand : KernelToMDNodeList->operands()) { 152*bdd1243dSDimitry Andric ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand); 153*bdd1243dSDimitry Andric if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() == 154*bdd1243dSDimitry Andric KernelFunction.getName()) { 155*bdd1243dSDimitry Andric FoundLoweredKernelFunction = true; 156*bdd1243dSDimitry Andric continue; 157*bdd1243dSDimitry Andric } 158*bdd1243dSDimitry Andric if (MaybeValue && FoundLoweredKernelFunction) 159*bdd1243dSDimitry Andric return nullptr; 160*bdd1243dSDimitry Andric 161*bdd1243dSDimitry Andric MDNode *MaybeNode = dyn_cast<MDNode>(Operand); 162*bdd1243dSDimitry Andric if (FoundLoweredKernelFunction && MaybeNode && 163*bdd1243dSDimitry Andric cast<MDString>(MaybeNode->getOperand(0))->getString() == 164*bdd1243dSDimitry Andric AttributeName && 165*bdd1243dSDimitry Andric MDArgIdx < MaybeNode->getNumOperands()) 166*bdd1243dSDimitry Andric return cast<MDString>(MaybeNode->getOperand(MDArgIdx)); 167*bdd1243dSDimitry Andric } 168*bdd1243dSDimitry Andric return nullptr; 169*bdd1243dSDimitry Andric } 170*bdd1243dSDimitry Andric 171*bdd1243dSDimitry Andric static SPIRV::AccessQualifier::AccessQualifier 172*bdd1243dSDimitry Andric getArgAccessQual(const Function &F, unsigned ArgIdx) { 173*bdd1243dSDimitry Andric if (F.getCallingConv() != CallingConv::SPIR_KERNEL) 174*bdd1243dSDimitry Andric return SPIRV::AccessQualifier::ReadWrite; 175*bdd1243dSDimitry Andric 176*bdd1243dSDimitry Andric MDString *ArgAttribute = 177*bdd1243dSDimitry Andric getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual"); 178*bdd1243dSDimitry Andric if (!ArgAttribute) 179*bdd1243dSDimitry Andric return SPIRV::AccessQualifier::ReadWrite; 180*bdd1243dSDimitry Andric 181*bdd1243dSDimitry Andric if (ArgAttribute->getString().compare("read_only") == 0) 182*bdd1243dSDimitry Andric return SPIRV::AccessQualifier::ReadOnly; 183*bdd1243dSDimitry Andric if (ArgAttribute->getString().compare("write_only") == 0) 184*bdd1243dSDimitry Andric return SPIRV::AccessQualifier::WriteOnly; 185*bdd1243dSDimitry Andric return SPIRV::AccessQualifier::ReadWrite; 186*bdd1243dSDimitry Andric } 187*bdd1243dSDimitry Andric 188*bdd1243dSDimitry Andric static std::vector<SPIRV::Decoration::Decoration> 189*bdd1243dSDimitry Andric getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) { 190*bdd1243dSDimitry Andric MDString *ArgAttribute = 191*bdd1243dSDimitry Andric getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual"); 192*bdd1243dSDimitry Andric if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0) 193*bdd1243dSDimitry Andric return {SPIRV::Decoration::Volatile}; 194*bdd1243dSDimitry Andric return {}; 195*bdd1243dSDimitry Andric } 196*bdd1243dSDimitry Andric 197*bdd1243dSDimitry Andric static Type *getArgType(const Function &F, unsigned ArgIdx) { 198*bdd1243dSDimitry Andric Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx); 199*bdd1243dSDimitry Andric if (F.getCallingConv() != CallingConv::SPIR_KERNEL || 200*bdd1243dSDimitry Andric isSpecialOpaqueType(OriginalArgType)) 201*bdd1243dSDimitry Andric return OriginalArgType; 202*bdd1243dSDimitry Andric 203*bdd1243dSDimitry Andric MDString *MDKernelArgType = 204*bdd1243dSDimitry Andric getKernelArgAttribute(F, ArgIdx, "kernel_arg_type"); 205*bdd1243dSDimitry Andric if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t")) 206*bdd1243dSDimitry Andric return OriginalArgType; 207*bdd1243dSDimitry Andric 208*bdd1243dSDimitry Andric std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str(); 209*bdd1243dSDimitry Andric Type *ExistingOpaqueType = 210*bdd1243dSDimitry Andric StructType::getTypeByName(F.getContext(), KernelArgTypeStr); 211*bdd1243dSDimitry Andric return ExistingOpaqueType 212*bdd1243dSDimitry Andric ? ExistingOpaqueType 213*bdd1243dSDimitry Andric : StructType::create(F.getContext(), KernelArgTypeStr); 214*bdd1243dSDimitry Andric } 215*bdd1243dSDimitry Andric 21681ad6265SDimitry Andric bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, 21781ad6265SDimitry Andric const Function &F, 21881ad6265SDimitry Andric ArrayRef<ArrayRef<Register>> VRegs, 21981ad6265SDimitry Andric FunctionLoweringInfo &FLI) const { 22081ad6265SDimitry Andric assert(GR && "Must initialize the SPIRV type registry before lowering args."); 221753f127fSDimitry Andric GR->setCurrentFunc(MIRBuilder.getMF()); 22281ad6265SDimitry Andric 22381ad6265SDimitry Andric // Assign types and names to all args, and store their types for later. 224fcaf7f86SDimitry Andric FunctionType *FTy = getOriginalFunctionType(F); 225fcaf7f86SDimitry Andric SmallVector<SPIRVType *, 4> ArgTypeVRegs; 22681ad6265SDimitry Andric if (VRegs.size() > 0) { 22781ad6265SDimitry Andric unsigned i = 0; 22881ad6265SDimitry Andric for (const auto &Arg : F.args()) { 22981ad6265SDimitry Andric // Currently formal args should use single registers. 23081ad6265SDimitry Andric // TODO: handle the case of multiple registers. 23181ad6265SDimitry Andric if (VRegs[i].size() > 1) 23281ad6265SDimitry Andric return false; 233*bdd1243dSDimitry Andric SPIRV::AccessQualifier::AccessQualifier ArgAccessQual = 234*bdd1243dSDimitry Andric getArgAccessQual(F, i); 235*bdd1243dSDimitry Andric auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0], 236*bdd1243dSDimitry Andric MIRBuilder, ArgAccessQual); 237fcaf7f86SDimitry Andric ArgTypeVRegs.push_back(SpirvTy); 23881ad6265SDimitry Andric 23981ad6265SDimitry Andric if (Arg.hasName()) 24081ad6265SDimitry Andric buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); 24181ad6265SDimitry Andric if (Arg.getType()->isPointerTy()) { 24281ad6265SDimitry Andric auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes()); 24381ad6265SDimitry Andric if (DerefBytes != 0) 24481ad6265SDimitry Andric buildOpDecorate(VRegs[i][0], MIRBuilder, 24581ad6265SDimitry Andric SPIRV::Decoration::MaxByteOffset, {DerefBytes}); 24681ad6265SDimitry Andric } 24781ad6265SDimitry Andric if (Arg.hasAttribute(Attribute::Alignment)) { 248fcaf7f86SDimitry Andric auto Alignment = static_cast<unsigned>( 249fcaf7f86SDimitry Andric Arg.getAttribute(Attribute::Alignment).getValueAsInt()); 25081ad6265SDimitry Andric buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, 251fcaf7f86SDimitry Andric {Alignment}); 25281ad6265SDimitry Andric } 25381ad6265SDimitry Andric if (Arg.hasAttribute(Attribute::ReadOnly)) { 25481ad6265SDimitry Andric auto Attr = 25581ad6265SDimitry Andric static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite); 25681ad6265SDimitry Andric buildOpDecorate(VRegs[i][0], MIRBuilder, 25781ad6265SDimitry Andric SPIRV::Decoration::FuncParamAttr, {Attr}); 25881ad6265SDimitry Andric } 25981ad6265SDimitry Andric if (Arg.hasAttribute(Attribute::ZExt)) { 26081ad6265SDimitry Andric auto Attr = 26181ad6265SDimitry Andric static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext); 26281ad6265SDimitry Andric buildOpDecorate(VRegs[i][0], MIRBuilder, 26381ad6265SDimitry Andric SPIRV::Decoration::FuncParamAttr, {Attr}); 26481ad6265SDimitry Andric } 265fcaf7f86SDimitry Andric if (Arg.hasAttribute(Attribute::NoAlias)) { 266fcaf7f86SDimitry Andric auto Attr = 267fcaf7f86SDimitry Andric static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias); 268fcaf7f86SDimitry Andric buildOpDecorate(VRegs[i][0], MIRBuilder, 269fcaf7f86SDimitry Andric SPIRV::Decoration::FuncParamAttr, {Attr}); 270fcaf7f86SDimitry Andric } 271*bdd1243dSDimitry Andric 272*bdd1243dSDimitry Andric if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { 273*bdd1243dSDimitry Andric std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs = 274*bdd1243dSDimitry Andric getKernelArgTypeQual(F, i); 275*bdd1243dSDimitry Andric for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs) 276*bdd1243dSDimitry Andric buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {}); 277fcaf7f86SDimitry Andric } 278*bdd1243dSDimitry Andric 279*bdd1243dSDimitry Andric MDNode *Node = F.getMetadata("spirv.ParameterDecorations"); 280fcaf7f86SDimitry Andric if (Node && i < Node->getNumOperands() && 281fcaf7f86SDimitry Andric isa<MDNode>(Node->getOperand(i))) { 282fcaf7f86SDimitry Andric MDNode *MD = cast<MDNode>(Node->getOperand(i)); 283fcaf7f86SDimitry Andric for (const MDOperand &MDOp : MD->operands()) { 284fcaf7f86SDimitry Andric MDNode *MD2 = dyn_cast<MDNode>(MDOp); 285fcaf7f86SDimitry Andric assert(MD2 && "Metadata operand is expected"); 286fcaf7f86SDimitry Andric ConstantInt *Const = getConstInt(MD2, 0); 287fcaf7f86SDimitry Andric assert(Const && "MDOperand should be ConstantInt"); 288*bdd1243dSDimitry Andric auto Dec = 289*bdd1243dSDimitry Andric static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue()); 290fcaf7f86SDimitry Andric std::vector<uint32_t> DecVec; 291fcaf7f86SDimitry Andric for (unsigned j = 1; j < MD2->getNumOperands(); j++) { 292fcaf7f86SDimitry Andric ConstantInt *Const = getConstInt(MD2, j); 293fcaf7f86SDimitry Andric assert(Const && "MDOperand should be ConstantInt"); 294fcaf7f86SDimitry Andric DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue())); 295fcaf7f86SDimitry Andric } 296fcaf7f86SDimitry Andric buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec); 297fcaf7f86SDimitry Andric } 298fcaf7f86SDimitry Andric } 29981ad6265SDimitry Andric ++i; 30081ad6265SDimitry Andric } 30181ad6265SDimitry Andric } 30281ad6265SDimitry Andric 30381ad6265SDimitry Andric // Generate a SPIR-V type for the function. 30481ad6265SDimitry Andric auto MRI = MIRBuilder.getMRI(); 30581ad6265SDimitry Andric Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 30681ad6265SDimitry Andric MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); 307753f127fSDimitry Andric if (F.isDeclaration()) 308753f127fSDimitry Andric GR->add(&F, &MIRBuilder.getMF(), FuncVReg); 309fcaf7f86SDimitry Andric SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); 310fcaf7f86SDimitry Andric SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( 311fcaf7f86SDimitry Andric FTy, RetTy, ArgTypeVRegs, MIRBuilder); 31281ad6265SDimitry Andric 31381ad6265SDimitry Andric // Build the OpTypeFunction declaring it. 31481ad6265SDimitry Andric uint32_t FuncControl = getFunctionControl(F); 31581ad6265SDimitry Andric 31681ad6265SDimitry Andric MIRBuilder.buildInstr(SPIRV::OpFunction) 31781ad6265SDimitry Andric .addDef(FuncVReg) 318fcaf7f86SDimitry Andric .addUse(GR->getSPIRVTypeID(RetTy)) 31981ad6265SDimitry Andric .addImm(FuncControl) 32081ad6265SDimitry Andric .addUse(GR->getSPIRVTypeID(FuncTy)); 32181ad6265SDimitry Andric 32281ad6265SDimitry Andric // Add OpFunctionParameters. 323fcaf7f86SDimitry Andric int i = 0; 324fcaf7f86SDimitry Andric for (const auto &Arg : F.args()) { 32581ad6265SDimitry Andric assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); 32681ad6265SDimitry Andric MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); 32781ad6265SDimitry Andric MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) 32881ad6265SDimitry Andric .addDef(VRegs[i][0]) 329fcaf7f86SDimitry Andric .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); 330753f127fSDimitry Andric if (F.isDeclaration()) 331fcaf7f86SDimitry Andric GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]); 332fcaf7f86SDimitry Andric i++; 33381ad6265SDimitry Andric } 33481ad6265SDimitry Andric // Name the function. 33581ad6265SDimitry Andric if (F.hasName()) 33681ad6265SDimitry Andric buildOpName(FuncVReg, F.getName(), MIRBuilder); 33781ad6265SDimitry Andric 33881ad6265SDimitry Andric // Handle entry points and function linkage. 33981ad6265SDimitry Andric if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { 34081ad6265SDimitry Andric auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) 34181ad6265SDimitry Andric .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel)) 34281ad6265SDimitry Andric .addUse(FuncVReg); 34381ad6265SDimitry Andric addStringImm(F.getName(), MIB); 34481ad6265SDimitry Andric } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage || 34581ad6265SDimitry Andric F.getLinkage() == GlobalValue::LinkOnceODRLinkage) { 34681ad6265SDimitry Andric auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import 34781ad6265SDimitry Andric : SPIRV::LinkageType::Export; 34881ad6265SDimitry Andric buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 34981ad6265SDimitry Andric {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier()); 35081ad6265SDimitry Andric } 35181ad6265SDimitry Andric 35281ad6265SDimitry Andric return true; 35381ad6265SDimitry Andric } 35481ad6265SDimitry Andric 35581ad6265SDimitry Andric bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, 35681ad6265SDimitry Andric CallLoweringInfo &Info) const { 35781ad6265SDimitry Andric // Currently call returns should have single vregs. 35881ad6265SDimitry Andric // TODO: handle the case of multiple registers. 35981ad6265SDimitry Andric if (Info.OrigRet.Regs.size() > 1) 36081ad6265SDimitry Andric return false; 361fcaf7f86SDimitry Andric MachineFunction &MF = MIRBuilder.getMF(); 362fcaf7f86SDimitry Andric GR->setCurrentFunc(MF); 363fcaf7f86SDimitry Andric FunctionType *FTy = nullptr; 364fcaf7f86SDimitry Andric const Function *CF = nullptr; 36581ad6265SDimitry Andric 36681ad6265SDimitry Andric // Emit a regular OpFunctionCall. If it's an externally declared function, 367fcaf7f86SDimitry Andric // be sure to emit its type and function declaration here. It will be hoisted 368fcaf7f86SDimitry Andric // globally later. 36981ad6265SDimitry Andric if (Info.Callee.isGlobal()) { 370fcaf7f86SDimitry Andric CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); 37181ad6265SDimitry Andric // TODO: support constexpr casts and indirect calls. 37281ad6265SDimitry Andric if (CF == nullptr) 37381ad6265SDimitry Andric return false; 374fcaf7f86SDimitry Andric FTy = getOriginalFunctionType(*CF); 375fcaf7f86SDimitry Andric } 376fcaf7f86SDimitry Andric 377fcaf7f86SDimitry Andric Register ResVReg = 378fcaf7f86SDimitry Andric Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; 379*bdd1243dSDimitry Andric std::string FuncName = Info.Callee.getGlobal()->getName().str(); 380*bdd1243dSDimitry Andric std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); 381*bdd1243dSDimitry Andric const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget()); 382*bdd1243dSDimitry Andric // TODO: check that it's OCL builtin, then apply OpenCL_std. 383*bdd1243dSDimitry Andric if (!DemangledName.empty() && CF && CF->isDeclaration() && 384*bdd1243dSDimitry Andric ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { 385*bdd1243dSDimitry Andric const Type *OrigRetTy = Info.OrigRet.Ty; 386*bdd1243dSDimitry Andric if (FTy) 387*bdd1243dSDimitry Andric OrigRetTy = FTy->getReturnType(); 388*bdd1243dSDimitry Andric SmallVector<Register, 8> ArgVRegs; 389*bdd1243dSDimitry Andric for (auto Arg : Info.OrigArgs) { 390*bdd1243dSDimitry Andric assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); 391*bdd1243dSDimitry Andric ArgVRegs.push_back(Arg.Regs[0]); 392*bdd1243dSDimitry Andric SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder); 393*bdd1243dSDimitry Andric GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF()); 394*bdd1243dSDimitry Andric } 395*bdd1243dSDimitry Andric if (auto Res = SPIRV::lowerBuiltin( 396*bdd1243dSDimitry Andric DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder, 397*bdd1243dSDimitry Andric ResVReg, OrigRetTy, ArgVRegs, GR)) 398*bdd1243dSDimitry Andric return *Res; 399*bdd1243dSDimitry Andric } 400fcaf7f86SDimitry Andric if (CF && CF->isDeclaration() && 401fcaf7f86SDimitry Andric !GR->find(CF, &MIRBuilder.getMF()).isValid()) { 40281ad6265SDimitry Andric // Emit the type info and forward function declaration to the first MBB 40381ad6265SDimitry Andric // to ensure VReg definition dependencies are valid across all MBBs. 404fcaf7f86SDimitry Andric MachineIRBuilder FirstBlockBuilder; 405fcaf7f86SDimitry Andric FirstBlockBuilder.setMF(MF); 406fcaf7f86SDimitry Andric FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0)); 40781ad6265SDimitry Andric 40881ad6265SDimitry Andric SmallVector<ArrayRef<Register>, 8> VRegArgs; 40981ad6265SDimitry Andric SmallVector<SmallVector<Register, 1>, 8> ToInsert; 41081ad6265SDimitry Andric for (const Argument &Arg : CF->args()) { 41181ad6265SDimitry Andric if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) 41281ad6265SDimitry Andric continue; // Don't handle zero sized types. 413fcaf7f86SDimitry Andric ToInsert.push_back( 414fcaf7f86SDimitry Andric {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))}); 41581ad6265SDimitry Andric VRegArgs.push_back(ToInsert.back()); 41681ad6265SDimitry Andric } 417fcaf7f86SDimitry Andric // TODO: Reuse FunctionLoweringInfo 41881ad6265SDimitry Andric FunctionLoweringInfo FuncInfo; 419fcaf7f86SDimitry Andric lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo); 42081ad6265SDimitry Andric } 42181ad6265SDimitry Andric 42281ad6265SDimitry Andric // Make sure there's a valid return reg, even for functions returning void. 423fcaf7f86SDimitry Andric if (!ResVReg.isValid()) 42481ad6265SDimitry Andric ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 42581ad6265SDimitry Andric SPIRVType *RetType = 426fcaf7f86SDimitry Andric GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder); 42781ad6265SDimitry Andric 42881ad6265SDimitry Andric // Emit the OpFunctionCall and its args. 42981ad6265SDimitry Andric auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) 43081ad6265SDimitry Andric .addDef(ResVReg) 43181ad6265SDimitry Andric .addUse(GR->getSPIRVTypeID(RetType)) 43281ad6265SDimitry Andric .add(Info.Callee); 43381ad6265SDimitry Andric 43481ad6265SDimitry Andric for (const auto &Arg : Info.OrigArgs) { 43581ad6265SDimitry Andric // Currently call args should have single vregs. 43681ad6265SDimitry Andric if (Arg.Regs.size() > 1) 43781ad6265SDimitry Andric return false; 43881ad6265SDimitry Andric MIB.addUse(Arg.Regs[0]); 43981ad6265SDimitry Andric } 440*bdd1243dSDimitry Andric return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(), 441*bdd1243dSDimitry Andric *ST->getRegBankInfo()); 44281ad6265SDimitry Andric } 443