xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (revision 1db9f3b21e39176dd5b67cf8ac378633b172463e)
181ad6265SDimitry Andric //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
281ad6265SDimitry Andric //
381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681ad6265SDimitry Andric //
781ad6265SDimitry Andric //===----------------------------------------------------------------------===//
881ad6265SDimitry Andric //
981ad6265SDimitry Andric // This file implements the lowering of LLVM calls to machine code calls for
1081ad6265SDimitry Andric // GlobalISel.
1181ad6265SDimitry Andric //
1281ad6265SDimitry Andric //===----------------------------------------------------------------------===//
1381ad6265SDimitry Andric 
1481ad6265SDimitry Andric #include "SPIRVCallLowering.h"
1581ad6265SDimitry Andric #include "MCTargetDesc/SPIRVBaseInfo.h"
1681ad6265SDimitry Andric #include "SPIRV.h"
17bdd1243dSDimitry Andric #include "SPIRVBuiltins.h"
1881ad6265SDimitry Andric #include "SPIRVGlobalRegistry.h"
1981ad6265SDimitry Andric #include "SPIRVISelLowering.h"
2081ad6265SDimitry Andric #include "SPIRVRegisterInfo.h"
2181ad6265SDimitry Andric #include "SPIRVSubtarget.h"
2281ad6265SDimitry Andric #include "SPIRVUtils.h"
2381ad6265SDimitry Andric #include "llvm/CodeGen/FunctionLoweringInfo.h"
24bdd1243dSDimitry Andric #include "llvm/Support/ModRef.h"
2581ad6265SDimitry Andric 
2681ad6265SDimitry Andric using namespace llvm;
2781ad6265SDimitry Andric 
2881ad6265SDimitry Andric SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
2981ad6265SDimitry Andric                                      SPIRVGlobalRegistry *GR)
30fcaf7f86SDimitry Andric     : CallLowering(&TLI), GR(GR) {}
3181ad6265SDimitry Andric 
3281ad6265SDimitry Andric bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
3381ad6265SDimitry Andric                                     const Value *Val, ArrayRef<Register> VRegs,
3481ad6265SDimitry Andric                                     FunctionLoweringInfo &FLI,
3581ad6265SDimitry Andric                                     Register SwiftErrorVReg) const {
3681ad6265SDimitry Andric   // Currently all return types should use a single register.
3781ad6265SDimitry Andric   // TODO: handle the case of multiple registers.
3881ad6265SDimitry Andric   if (VRegs.size() > 1)
3981ad6265SDimitry Andric     return false;
40fcaf7f86SDimitry Andric   if (Val) {
41fcaf7f86SDimitry Andric     const auto &STI = MIRBuilder.getMF().getSubtarget();
4281ad6265SDimitry Andric     return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
4381ad6265SDimitry Andric         .addUse(VRegs[0])
44fcaf7f86SDimitry Andric         .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
45fcaf7f86SDimitry Andric                           *STI.getRegBankInfo());
46fcaf7f86SDimitry Andric   }
4781ad6265SDimitry Andric   MIRBuilder.buildInstr(SPIRV::OpReturn);
4881ad6265SDimitry Andric   return true;
4981ad6265SDimitry Andric }
5081ad6265SDimitry Andric 
5181ad6265SDimitry Andric // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
5281ad6265SDimitry Andric static uint32_t getFunctionControl(const Function &F) {
53bdd1243dSDimitry Andric   MemoryEffects MemEffects = F.getMemoryEffects();
54bdd1243dSDimitry Andric 
5581ad6265SDimitry Andric   uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
56bdd1243dSDimitry Andric 
57bdd1243dSDimitry Andric   if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
5881ad6265SDimitry Andric     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
59bdd1243dSDimitry Andric   else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
60bdd1243dSDimitry Andric     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
61bdd1243dSDimitry Andric 
62bdd1243dSDimitry Andric   if (MemEffects.doesNotAccessMemory())
63bdd1243dSDimitry Andric     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
64bdd1243dSDimitry Andric   else if (MemEffects.onlyReadsMemory())
65bdd1243dSDimitry Andric     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
66bdd1243dSDimitry Andric 
6781ad6265SDimitry Andric   return FuncControl;
6881ad6265SDimitry Andric }
6981ad6265SDimitry Andric 
70fcaf7f86SDimitry Andric static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
71fcaf7f86SDimitry Andric   if (MD->getNumOperands() > NumOp) {
72fcaf7f86SDimitry Andric     auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
73fcaf7f86SDimitry Andric     if (CMeta)
74fcaf7f86SDimitry Andric       return dyn_cast<ConstantInt>(CMeta->getValue());
75fcaf7f86SDimitry Andric   }
76fcaf7f86SDimitry Andric   return nullptr;
77fcaf7f86SDimitry Andric }
78fcaf7f86SDimitry Andric 
79fcaf7f86SDimitry Andric // This code restores function args/retvalue types for composite cases
80fcaf7f86SDimitry Andric // because the final types should still be aggregate whereas they're i32
81fcaf7f86SDimitry Andric // during the translation to cope with aggregate flattening etc.
82fcaf7f86SDimitry Andric static FunctionType *getOriginalFunctionType(const Function &F) {
83fcaf7f86SDimitry Andric   auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
84fcaf7f86SDimitry Andric   if (NamedMD == nullptr)
85fcaf7f86SDimitry Andric     return F.getFunctionType();
86fcaf7f86SDimitry Andric 
87fcaf7f86SDimitry Andric   Type *RetTy = F.getFunctionType()->getReturnType();
88fcaf7f86SDimitry Andric   SmallVector<Type *, 4> ArgTypes;
89fcaf7f86SDimitry Andric   for (auto &Arg : F.args())
90fcaf7f86SDimitry Andric     ArgTypes.push_back(Arg.getType());
91fcaf7f86SDimitry Andric 
92fcaf7f86SDimitry Andric   auto ThisFuncMDIt =
93fcaf7f86SDimitry Andric       std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
94fcaf7f86SDimitry Andric         return isa<MDString>(N->getOperand(0)) &&
95fcaf7f86SDimitry Andric                cast<MDString>(N->getOperand(0))->getString() == F.getName();
96fcaf7f86SDimitry Andric       });
97fcaf7f86SDimitry Andric   // TODO: probably one function can have numerous type mutations,
98fcaf7f86SDimitry Andric   // so we should support this.
99fcaf7f86SDimitry Andric   if (ThisFuncMDIt != NamedMD->op_end()) {
100fcaf7f86SDimitry Andric     auto *ThisFuncMD = *ThisFuncMDIt;
101fcaf7f86SDimitry Andric     MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
102fcaf7f86SDimitry Andric     assert(MD && "MDNode operand is expected");
103fcaf7f86SDimitry Andric     ConstantInt *Const = getConstInt(MD, 0);
104fcaf7f86SDimitry Andric     if (Const) {
105fcaf7f86SDimitry Andric       auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
106fcaf7f86SDimitry Andric       assert(CMeta && "ConstantAsMetadata operand is expected");
107fcaf7f86SDimitry Andric       assert(Const->getSExtValue() >= -1);
108fcaf7f86SDimitry Andric       // Currently -1 indicates return value, greater values mean
109fcaf7f86SDimitry Andric       // argument numbers.
110fcaf7f86SDimitry Andric       if (Const->getSExtValue() == -1)
111fcaf7f86SDimitry Andric         RetTy = CMeta->getType();
112fcaf7f86SDimitry Andric       else
113fcaf7f86SDimitry Andric         ArgTypes[Const->getSExtValue()] = CMeta->getType();
114fcaf7f86SDimitry Andric     }
115fcaf7f86SDimitry Andric   }
116fcaf7f86SDimitry Andric 
117fcaf7f86SDimitry Andric   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
118fcaf7f86SDimitry Andric }
119fcaf7f86SDimitry Andric 
120bdd1243dSDimitry Andric static MDString *getKernelArgAttribute(const Function &KernelFunction,
121bdd1243dSDimitry Andric                                        unsigned ArgIdx,
122bdd1243dSDimitry Andric                                        const StringRef AttributeName) {
123bdd1243dSDimitry Andric   assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
124bdd1243dSDimitry Andric          "Kernel attributes are attached/belong only to kernel functions");
125bdd1243dSDimitry Andric 
126bdd1243dSDimitry Andric   // Lookup the argument attribute in metadata attached to the kernel function.
127bdd1243dSDimitry Andric   MDNode *Node = KernelFunction.getMetadata(AttributeName);
128bdd1243dSDimitry Andric   if (Node && ArgIdx < Node->getNumOperands())
129bdd1243dSDimitry Andric     return cast<MDString>(Node->getOperand(ArgIdx));
130bdd1243dSDimitry Andric 
131bdd1243dSDimitry Andric   // Sometimes metadata containing kernel attributes is not attached to the
132bdd1243dSDimitry Andric   // function, but can be found in the named module-level metadata instead.
133bdd1243dSDimitry Andric   // For example:
134bdd1243dSDimitry Andric   //   !opencl.kernels = !{!0}
135bdd1243dSDimitry Andric   //   !0 = !{void ()* @someKernelFunction, !1, ...}
136bdd1243dSDimitry Andric   //   !1 = !{!"kernel_arg_addr_space", ...}
137bdd1243dSDimitry Andric   // In this case the actual index of searched argument attribute is ArgIdx + 1,
138bdd1243dSDimitry Andric   // since the first metadata node operand is occupied by attribute name
139bdd1243dSDimitry Andric   // ("kernel_arg_addr_space" in the example above).
140bdd1243dSDimitry Andric   unsigned MDArgIdx = ArgIdx + 1;
141bdd1243dSDimitry Andric   NamedMDNode *OpenCLKernelsMD =
142bdd1243dSDimitry Andric       KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
143bdd1243dSDimitry Andric   if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
144bdd1243dSDimitry Andric     return nullptr;
145bdd1243dSDimitry Andric 
146bdd1243dSDimitry Andric   // KernelToMDNodeList contains kernel function declarations followed by
147bdd1243dSDimitry Andric   // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
148bdd1243dSDimitry Andric   // to the currently lowered kernel function.
149bdd1243dSDimitry Andric   MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
150bdd1243dSDimitry Andric   bool FoundLoweredKernelFunction = false;
151bdd1243dSDimitry Andric   for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
152bdd1243dSDimitry Andric     ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
153bdd1243dSDimitry Andric     if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
154bdd1243dSDimitry Andric                           KernelFunction.getName()) {
155bdd1243dSDimitry Andric       FoundLoweredKernelFunction = true;
156bdd1243dSDimitry Andric       continue;
157bdd1243dSDimitry Andric     }
158bdd1243dSDimitry Andric     if (MaybeValue && FoundLoweredKernelFunction)
159bdd1243dSDimitry Andric       return nullptr;
160bdd1243dSDimitry Andric 
161bdd1243dSDimitry Andric     MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
162bdd1243dSDimitry Andric     if (FoundLoweredKernelFunction && MaybeNode &&
163bdd1243dSDimitry Andric         cast<MDString>(MaybeNode->getOperand(0))->getString() ==
164bdd1243dSDimitry Andric             AttributeName &&
165bdd1243dSDimitry Andric         MDArgIdx < MaybeNode->getNumOperands())
166bdd1243dSDimitry Andric       return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
167bdd1243dSDimitry Andric   }
168bdd1243dSDimitry Andric   return nullptr;
169bdd1243dSDimitry Andric }
170bdd1243dSDimitry Andric 
171bdd1243dSDimitry Andric static SPIRV::AccessQualifier::AccessQualifier
172bdd1243dSDimitry Andric getArgAccessQual(const Function &F, unsigned ArgIdx) {
173bdd1243dSDimitry Andric   if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
174bdd1243dSDimitry Andric     return SPIRV::AccessQualifier::ReadWrite;
175bdd1243dSDimitry Andric 
176bdd1243dSDimitry Andric   MDString *ArgAttribute =
177bdd1243dSDimitry Andric       getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
178bdd1243dSDimitry Andric   if (!ArgAttribute)
179bdd1243dSDimitry Andric     return SPIRV::AccessQualifier::ReadWrite;
180bdd1243dSDimitry Andric 
181bdd1243dSDimitry Andric   if (ArgAttribute->getString().compare("read_only") == 0)
182bdd1243dSDimitry Andric     return SPIRV::AccessQualifier::ReadOnly;
183bdd1243dSDimitry Andric   if (ArgAttribute->getString().compare("write_only") == 0)
184bdd1243dSDimitry Andric     return SPIRV::AccessQualifier::WriteOnly;
185bdd1243dSDimitry Andric   return SPIRV::AccessQualifier::ReadWrite;
186bdd1243dSDimitry Andric }
187bdd1243dSDimitry Andric 
188bdd1243dSDimitry Andric static std::vector<SPIRV::Decoration::Decoration>
189bdd1243dSDimitry Andric getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
190bdd1243dSDimitry Andric   MDString *ArgAttribute =
191bdd1243dSDimitry Andric       getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
192bdd1243dSDimitry Andric   if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
193bdd1243dSDimitry Andric     return {SPIRV::Decoration::Volatile};
194bdd1243dSDimitry Andric   return {};
195bdd1243dSDimitry Andric }
196bdd1243dSDimitry Andric 
1975f757f3fSDimitry Andric static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
1985f757f3fSDimitry Andric                                   SPIRVGlobalRegistry *GR,
1995f757f3fSDimitry Andric                                   MachineIRBuilder &MIRBuilder) {
2005f757f3fSDimitry Andric   // Read argument's access qualifier from metadata or default.
2015f757f3fSDimitry Andric   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
2025f757f3fSDimitry Andric       getArgAccessQual(F, ArgIdx);
2035f757f3fSDimitry Andric 
204bdd1243dSDimitry Andric   Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
2055f757f3fSDimitry Andric 
2065f757f3fSDimitry Andric   // In case of non-kernel SPIR-V function or already TargetExtType, use the
2075f757f3fSDimitry Andric   // original IR type.
208bdd1243dSDimitry Andric   if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
209bdd1243dSDimitry Andric       isSpecialOpaqueType(OriginalArgType))
2105f757f3fSDimitry Andric     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
211bdd1243dSDimitry Andric 
212bdd1243dSDimitry Andric   MDString *MDKernelArgType =
213bdd1243dSDimitry Andric       getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
214*1db9f3b2SDimitry Andric   if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
215*1db9f3b2SDimitry Andric                            !MDKernelArgType->getString().ends_with("_t")))
2165f757f3fSDimitry Andric     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
217bdd1243dSDimitry Andric 
2185f757f3fSDimitry Andric   if (MDKernelArgType->getString().ends_with("*"))
2195f757f3fSDimitry Andric     return GR->getOrCreateSPIRVTypeByName(
2205f757f3fSDimitry Andric         MDKernelArgType->getString(), MIRBuilder,
2215f757f3fSDimitry Andric         addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));
2225f757f3fSDimitry Andric 
2235f757f3fSDimitry Andric   if (MDKernelArgType->getString().ends_with("_t"))
2245f757f3fSDimitry Andric     return GR->getOrCreateSPIRVTypeByName(
2255f757f3fSDimitry Andric         "opencl." + MDKernelArgType->getString().str(), MIRBuilder,
2265f757f3fSDimitry Andric         SPIRV::StorageClass::Function, ArgAccessQual);
2275f757f3fSDimitry Andric 
2285f757f3fSDimitry Andric   llvm_unreachable("Unable to recognize argument type name.");
2295f757f3fSDimitry Andric }
2305f757f3fSDimitry Andric 
2315f757f3fSDimitry Andric static bool isEntryPoint(const Function &F) {
2325f757f3fSDimitry Andric   // OpenCL handling: any function with the SPIR_KERNEL
2335f757f3fSDimitry Andric   // calling convention will be a potential entry point.
2345f757f3fSDimitry Andric   if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
2355f757f3fSDimitry Andric     return true;
2365f757f3fSDimitry Andric 
2375f757f3fSDimitry Andric   // HLSL handling: special attribute are emitted from the
2385f757f3fSDimitry Andric   // front-end.
2395f757f3fSDimitry Andric   if (F.getFnAttribute("hlsl.shader").isValid())
2405f757f3fSDimitry Andric     return true;
2415f757f3fSDimitry Andric 
2425f757f3fSDimitry Andric   return false;
2435f757f3fSDimitry Andric }
2445f757f3fSDimitry Andric 
2455f757f3fSDimitry Andric static SPIRV::ExecutionModel::ExecutionModel
2465f757f3fSDimitry Andric getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
2475f757f3fSDimitry Andric   if (STI.isOpenCLEnv())
2485f757f3fSDimitry Andric     return SPIRV::ExecutionModel::Kernel;
2495f757f3fSDimitry Andric 
2505f757f3fSDimitry Andric   auto attribute = F.getFnAttribute("hlsl.shader");
2515f757f3fSDimitry Andric   if (!attribute.isValid()) {
2525f757f3fSDimitry Andric     report_fatal_error(
2535f757f3fSDimitry Andric         "This entry point lacks mandatory hlsl.shader attribute.");
2545f757f3fSDimitry Andric   }
2555f757f3fSDimitry Andric 
2565f757f3fSDimitry Andric   const auto value = attribute.getValueAsString();
2575f757f3fSDimitry Andric   if (value == "compute")
2585f757f3fSDimitry Andric     return SPIRV::ExecutionModel::GLCompute;
2595f757f3fSDimitry Andric 
2605f757f3fSDimitry Andric   report_fatal_error("This HLSL entry point is not supported by this backend.");
261bdd1243dSDimitry Andric }
262bdd1243dSDimitry Andric 
26381ad6265SDimitry Andric bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
26481ad6265SDimitry Andric                                              const Function &F,
26581ad6265SDimitry Andric                                              ArrayRef<ArrayRef<Register>> VRegs,
26681ad6265SDimitry Andric                                              FunctionLoweringInfo &FLI) const {
26781ad6265SDimitry Andric   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
268753f127fSDimitry Andric   GR->setCurrentFunc(MIRBuilder.getMF());
26981ad6265SDimitry Andric 
27081ad6265SDimitry Andric   // Assign types and names to all args, and store their types for later.
271fcaf7f86SDimitry Andric   FunctionType *FTy = getOriginalFunctionType(F);
272fcaf7f86SDimitry Andric   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
27381ad6265SDimitry Andric   if (VRegs.size() > 0) {
27481ad6265SDimitry Andric     unsigned i = 0;
27581ad6265SDimitry Andric     for (const auto &Arg : F.args()) {
27681ad6265SDimitry Andric       // Currently formal args should use single registers.
27781ad6265SDimitry Andric       // TODO: handle the case of multiple registers.
27881ad6265SDimitry Andric       if (VRegs[i].size() > 1)
27981ad6265SDimitry Andric         return false;
2805f757f3fSDimitry Andric       auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
2815f757f3fSDimitry Andric       GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
282fcaf7f86SDimitry Andric       ArgTypeVRegs.push_back(SpirvTy);
28381ad6265SDimitry Andric 
28481ad6265SDimitry Andric       if (Arg.hasName())
28581ad6265SDimitry Andric         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
28681ad6265SDimitry Andric       if (Arg.getType()->isPointerTy()) {
28781ad6265SDimitry Andric         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
28881ad6265SDimitry Andric         if (DerefBytes != 0)
28981ad6265SDimitry Andric           buildOpDecorate(VRegs[i][0], MIRBuilder,
29081ad6265SDimitry Andric                           SPIRV::Decoration::MaxByteOffset, {DerefBytes});
29181ad6265SDimitry Andric       }
29281ad6265SDimitry Andric       if (Arg.hasAttribute(Attribute::Alignment)) {
293fcaf7f86SDimitry Andric         auto Alignment = static_cast<unsigned>(
294fcaf7f86SDimitry Andric             Arg.getAttribute(Attribute::Alignment).getValueAsInt());
29581ad6265SDimitry Andric         buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
296fcaf7f86SDimitry Andric                         {Alignment});
29781ad6265SDimitry Andric       }
29881ad6265SDimitry Andric       if (Arg.hasAttribute(Attribute::ReadOnly)) {
29981ad6265SDimitry Andric         auto Attr =
30081ad6265SDimitry Andric             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
30181ad6265SDimitry Andric         buildOpDecorate(VRegs[i][0], MIRBuilder,
30281ad6265SDimitry Andric                         SPIRV::Decoration::FuncParamAttr, {Attr});
30381ad6265SDimitry Andric       }
30481ad6265SDimitry Andric       if (Arg.hasAttribute(Attribute::ZExt)) {
30581ad6265SDimitry Andric         auto Attr =
30681ad6265SDimitry Andric             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
30781ad6265SDimitry Andric         buildOpDecorate(VRegs[i][0], MIRBuilder,
30881ad6265SDimitry Andric                         SPIRV::Decoration::FuncParamAttr, {Attr});
30981ad6265SDimitry Andric       }
310fcaf7f86SDimitry Andric       if (Arg.hasAttribute(Attribute::NoAlias)) {
311fcaf7f86SDimitry Andric         auto Attr =
312fcaf7f86SDimitry Andric             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
313fcaf7f86SDimitry Andric         buildOpDecorate(VRegs[i][0], MIRBuilder,
314fcaf7f86SDimitry Andric                         SPIRV::Decoration::FuncParamAttr, {Attr});
315fcaf7f86SDimitry Andric       }
316bdd1243dSDimitry Andric 
317bdd1243dSDimitry Andric       if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
318bdd1243dSDimitry Andric         std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
319bdd1243dSDimitry Andric             getKernelArgTypeQual(F, i);
320bdd1243dSDimitry Andric         for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
321bdd1243dSDimitry Andric           buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
322fcaf7f86SDimitry Andric       }
323bdd1243dSDimitry Andric 
324bdd1243dSDimitry Andric       MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
325fcaf7f86SDimitry Andric       if (Node && i < Node->getNumOperands() &&
326fcaf7f86SDimitry Andric           isa<MDNode>(Node->getOperand(i))) {
327fcaf7f86SDimitry Andric         MDNode *MD = cast<MDNode>(Node->getOperand(i));
328fcaf7f86SDimitry Andric         for (const MDOperand &MDOp : MD->operands()) {
329fcaf7f86SDimitry Andric           MDNode *MD2 = dyn_cast<MDNode>(MDOp);
330fcaf7f86SDimitry Andric           assert(MD2 && "Metadata operand is expected");
331fcaf7f86SDimitry Andric           ConstantInt *Const = getConstInt(MD2, 0);
332fcaf7f86SDimitry Andric           assert(Const && "MDOperand should be ConstantInt");
333bdd1243dSDimitry Andric           auto Dec =
334bdd1243dSDimitry Andric               static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
335fcaf7f86SDimitry Andric           std::vector<uint32_t> DecVec;
336fcaf7f86SDimitry Andric           for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
337fcaf7f86SDimitry Andric             ConstantInt *Const = getConstInt(MD2, j);
338fcaf7f86SDimitry Andric             assert(Const && "MDOperand should be ConstantInt");
339fcaf7f86SDimitry Andric             DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
340fcaf7f86SDimitry Andric           }
341fcaf7f86SDimitry Andric           buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
342fcaf7f86SDimitry Andric         }
343fcaf7f86SDimitry Andric       }
34481ad6265SDimitry Andric       ++i;
34581ad6265SDimitry Andric     }
34681ad6265SDimitry Andric   }
34781ad6265SDimitry Andric 
34881ad6265SDimitry Andric   // Generate a SPIR-V type for the function.
34981ad6265SDimitry Andric   auto MRI = MIRBuilder.getMRI();
35081ad6265SDimitry Andric   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
35181ad6265SDimitry Andric   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
352753f127fSDimitry Andric   if (F.isDeclaration())
353753f127fSDimitry Andric     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
354fcaf7f86SDimitry Andric   SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
355fcaf7f86SDimitry Andric   SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
356fcaf7f86SDimitry Andric       FTy, RetTy, ArgTypeVRegs, MIRBuilder);
35781ad6265SDimitry Andric 
35881ad6265SDimitry Andric   // Build the OpTypeFunction declaring it.
35981ad6265SDimitry Andric   uint32_t FuncControl = getFunctionControl(F);
36081ad6265SDimitry Andric 
36181ad6265SDimitry Andric   MIRBuilder.buildInstr(SPIRV::OpFunction)
36281ad6265SDimitry Andric       .addDef(FuncVReg)
363fcaf7f86SDimitry Andric       .addUse(GR->getSPIRVTypeID(RetTy))
36481ad6265SDimitry Andric       .addImm(FuncControl)
36581ad6265SDimitry Andric       .addUse(GR->getSPIRVTypeID(FuncTy));
36681ad6265SDimitry Andric 
36781ad6265SDimitry Andric   // Add OpFunctionParameters.
368fcaf7f86SDimitry Andric   int i = 0;
369fcaf7f86SDimitry Andric   for (const auto &Arg : F.args()) {
37081ad6265SDimitry Andric     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
37181ad6265SDimitry Andric     MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
37281ad6265SDimitry Andric     MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
37381ad6265SDimitry Andric         .addDef(VRegs[i][0])
374fcaf7f86SDimitry Andric         .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
375753f127fSDimitry Andric     if (F.isDeclaration())
376fcaf7f86SDimitry Andric       GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
377fcaf7f86SDimitry Andric     i++;
37881ad6265SDimitry Andric   }
37981ad6265SDimitry Andric   // Name the function.
38081ad6265SDimitry Andric   if (F.hasName())
38181ad6265SDimitry Andric     buildOpName(FuncVReg, F.getName(), MIRBuilder);
38281ad6265SDimitry Andric 
38381ad6265SDimitry Andric   // Handle entry points and function linkage.
3845f757f3fSDimitry Andric   if (isEntryPoint(F)) {
3855f757f3fSDimitry Andric     const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
3865f757f3fSDimitry Andric     auto executionModel = getExecutionModel(STI, F);
38781ad6265SDimitry Andric     auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
3885f757f3fSDimitry Andric                    .addImm(static_cast<uint32_t>(executionModel))
38981ad6265SDimitry Andric                    .addUse(FuncVReg);
39081ad6265SDimitry Andric     addStringImm(F.getName(), MIB);
39181ad6265SDimitry Andric   } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
39281ad6265SDimitry Andric              F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
39381ad6265SDimitry Andric     auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
39481ad6265SDimitry Andric                                    : SPIRV::LinkageType::Export;
39581ad6265SDimitry Andric     buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
39681ad6265SDimitry Andric                     {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
39781ad6265SDimitry Andric   }
39881ad6265SDimitry Andric 
39981ad6265SDimitry Andric   return true;
40081ad6265SDimitry Andric }
40181ad6265SDimitry Andric 
40281ad6265SDimitry Andric bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
40381ad6265SDimitry Andric                                   CallLoweringInfo &Info) const {
40481ad6265SDimitry Andric   // Currently call returns should have single vregs.
40581ad6265SDimitry Andric   // TODO: handle the case of multiple registers.
40681ad6265SDimitry Andric   if (Info.OrigRet.Regs.size() > 1)
40781ad6265SDimitry Andric     return false;
408fcaf7f86SDimitry Andric   MachineFunction &MF = MIRBuilder.getMF();
409fcaf7f86SDimitry Andric   GR->setCurrentFunc(MF);
410fcaf7f86SDimitry Andric   FunctionType *FTy = nullptr;
411fcaf7f86SDimitry Andric   const Function *CF = nullptr;
41281ad6265SDimitry Andric 
41381ad6265SDimitry Andric   // Emit a regular OpFunctionCall. If it's an externally declared function,
414fcaf7f86SDimitry Andric   // be sure to emit its type and function declaration here. It will be hoisted
415fcaf7f86SDimitry Andric   // globally later.
41681ad6265SDimitry Andric   if (Info.Callee.isGlobal()) {
417fcaf7f86SDimitry Andric     CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
41881ad6265SDimitry Andric     // TODO: support constexpr casts and indirect calls.
41981ad6265SDimitry Andric     if (CF == nullptr)
42081ad6265SDimitry Andric       return false;
421fcaf7f86SDimitry Andric     FTy = getOriginalFunctionType(*CF);
422fcaf7f86SDimitry Andric   }
423fcaf7f86SDimitry Andric 
42406c3fb27SDimitry Andric   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
425fcaf7f86SDimitry Andric   Register ResVReg =
426fcaf7f86SDimitry Andric       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
427bdd1243dSDimitry Andric   std::string FuncName = Info.Callee.getGlobal()->getName().str();
428bdd1243dSDimitry Andric   std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
429bdd1243dSDimitry Andric   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
430bdd1243dSDimitry Andric   // TODO: check that it's OCL builtin, then apply OpenCL_std.
431bdd1243dSDimitry Andric   if (!DemangledName.empty() && CF && CF->isDeclaration() &&
432bdd1243dSDimitry Andric       ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
433bdd1243dSDimitry Andric     const Type *OrigRetTy = Info.OrigRet.Ty;
434bdd1243dSDimitry Andric     if (FTy)
435bdd1243dSDimitry Andric       OrigRetTy = FTy->getReturnType();
436bdd1243dSDimitry Andric     SmallVector<Register, 8> ArgVRegs;
437bdd1243dSDimitry Andric     for (auto Arg : Info.OrigArgs) {
438bdd1243dSDimitry Andric       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
439bdd1243dSDimitry Andric       ArgVRegs.push_back(Arg.Regs[0]);
440bdd1243dSDimitry Andric       SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
441*1db9f3b2SDimitry Andric       if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
442bdd1243dSDimitry Andric         GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
443bdd1243dSDimitry Andric     }
444bdd1243dSDimitry Andric     if (auto Res = SPIRV::lowerBuiltin(
445bdd1243dSDimitry Andric             DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
446bdd1243dSDimitry Andric             ResVReg, OrigRetTy, ArgVRegs, GR))
447bdd1243dSDimitry Andric       return *Res;
448bdd1243dSDimitry Andric   }
449fcaf7f86SDimitry Andric   if (CF && CF->isDeclaration() &&
450fcaf7f86SDimitry Andric       !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
45181ad6265SDimitry Andric     // Emit the type info and forward function declaration to the first MBB
45281ad6265SDimitry Andric     // to ensure VReg definition dependencies are valid across all MBBs.
453fcaf7f86SDimitry Andric     MachineIRBuilder FirstBlockBuilder;
454fcaf7f86SDimitry Andric     FirstBlockBuilder.setMF(MF);
455fcaf7f86SDimitry Andric     FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
45681ad6265SDimitry Andric 
45781ad6265SDimitry Andric     SmallVector<ArrayRef<Register>, 8> VRegArgs;
45881ad6265SDimitry Andric     SmallVector<SmallVector<Register, 1>, 8> ToInsert;
45981ad6265SDimitry Andric     for (const Argument &Arg : CF->args()) {
46081ad6265SDimitry Andric       if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
46181ad6265SDimitry Andric         continue; // Don't handle zero sized types.
46206c3fb27SDimitry Andric       Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
46306c3fb27SDimitry Andric       MRI->setRegClass(Reg, &SPIRV::IDRegClass);
46406c3fb27SDimitry Andric       ToInsert.push_back({Reg});
46581ad6265SDimitry Andric       VRegArgs.push_back(ToInsert.back());
46681ad6265SDimitry Andric     }
467fcaf7f86SDimitry Andric     // TODO: Reuse FunctionLoweringInfo
46881ad6265SDimitry Andric     FunctionLoweringInfo FuncInfo;
469fcaf7f86SDimitry Andric     lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
47081ad6265SDimitry Andric   }
47181ad6265SDimitry Andric 
47281ad6265SDimitry Andric   // Make sure there's a valid return reg, even for functions returning void.
473fcaf7f86SDimitry Andric   if (!ResVReg.isValid())
47481ad6265SDimitry Andric     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
47581ad6265SDimitry Andric   SPIRVType *RetType =
476fcaf7f86SDimitry Andric       GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
47781ad6265SDimitry Andric 
47881ad6265SDimitry Andric   // Emit the OpFunctionCall and its args.
47981ad6265SDimitry Andric   auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
48081ad6265SDimitry Andric                  .addDef(ResVReg)
48181ad6265SDimitry Andric                  .addUse(GR->getSPIRVTypeID(RetType))
48281ad6265SDimitry Andric                  .add(Info.Callee);
48381ad6265SDimitry Andric 
48481ad6265SDimitry Andric   for (const auto &Arg : Info.OrigArgs) {
48581ad6265SDimitry Andric     // Currently call args should have single vregs.
48681ad6265SDimitry Andric     if (Arg.Regs.size() > 1)
48781ad6265SDimitry Andric       return false;
48881ad6265SDimitry Andric     MIB.addUse(Arg.Regs[0]);
48981ad6265SDimitry Andric   }
490bdd1243dSDimitry Andric   return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
491bdd1243dSDimitry Andric                               *ST->getRegBankInfo());
49281ad6265SDimitry Andric }
493