1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the lowering of LLVM calls to machine code calls for 10 // GlobalISel. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "SPIRVCallLowering.h" 15 #include "MCTargetDesc/SPIRVBaseInfo.h" 16 #include "SPIRV.h" 17 #include "SPIRVBuiltins.h" 18 #include "SPIRVGlobalRegistry.h" 19 #include "SPIRVISelLowering.h" 20 #include "SPIRVMetadata.h" 21 #include "SPIRVRegisterInfo.h" 22 #include "SPIRVSubtarget.h" 23 #include "SPIRVUtils.h" 24 #include "llvm/CodeGen/FunctionLoweringInfo.h" 25 #include "llvm/IR/IntrinsicInst.h" 26 #include "llvm/IR/IntrinsicsSPIRV.h" 27 #include "llvm/Support/ModRef.h" 28 29 using namespace llvm; 30 31 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, 32 SPIRVGlobalRegistry *GR) 33 : CallLowering(&TLI), GR(GR) {} 34 35 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, 36 const Value *Val, ArrayRef<Register> VRegs, 37 FunctionLoweringInfo &FLI, 38 Register SwiftErrorVReg) const { 39 // Ignore if called from the internal service function 40 if (MIRBuilder.getMF() 41 .getFunction() 42 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) 43 .isValid()) 44 return true; 45 46 // Maybe run postponed production of types for function pointers 47 if (IndirectCalls.size() > 0) { 48 produceIndirectPtrTypes(MIRBuilder); 49 IndirectCalls.clear(); 50 } 51 52 // Currently all return types should use a single register. 53 // TODO: handle the case of multiple registers. 54 if (VRegs.size() > 1) 55 return false; 56 if (Val) { 57 const auto &STI = MIRBuilder.getMF().getSubtarget(); 58 return MIRBuilder.buildInstr(SPIRV::OpReturnValue) 59 .addUse(VRegs[0]) 60 .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), 61 *STI.getRegBankInfo()); 62 } 63 MIRBuilder.buildInstr(SPIRV::OpReturn); 64 return true; 65 } 66 67 // Based on the LLVM function attributes, get a SPIR-V FunctionControl. 68 static uint32_t getFunctionControl(const Function &F, 69 const SPIRVSubtarget *ST) { 70 MemoryEffects MemEffects = F.getMemoryEffects(); 71 72 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None); 73 74 if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) 75 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); 76 else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) 77 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline); 78 79 if (MemEffects.doesNotAccessMemory()) 80 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure); 81 else if (MemEffects.onlyReadsMemory()) 82 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const); 83 84 if (ST->canUseExtension(SPIRV::Extension::SPV_INTEL_optnone) || 85 ST->canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) 86 if (F.hasFnAttribute(Attribute::OptimizeNone)) 87 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::OptNoneEXT); 88 89 return FuncControl; 90 } 91 92 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) { 93 if (MD->getNumOperands() > NumOp) { 94 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp)); 95 if (CMeta) 96 return dyn_cast<ConstantInt>(CMeta->getValue()); 97 } 98 return nullptr; 99 } 100 101 // If the function has pointer arguments, we are forced to re-create this 102 // function type from the very beginning, changing PointerType by 103 // TypedPointerType for each pointer argument. Otherwise, the same `Type*` 104 // potentially corresponds to different SPIR-V function type, effectively 105 // invalidating logic behind global registry and duplicates tracker. 106 static FunctionType * 107 fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F, 108 FunctionType *FTy, const SPIRVType *SRetTy, 109 const SmallVector<SPIRVType *, 4> &SArgTys) { 110 bool hasArgPtrs = false; 111 for (auto &Arg : F.args()) { 112 // check if it's an instance of a non-typed PointerType 113 if (Arg.getType()->isPointerTy()) { 114 hasArgPtrs = true; 115 break; 116 } 117 } 118 if (!hasArgPtrs) { 119 Type *RetTy = FTy->getReturnType(); 120 // check if it's an instance of a non-typed PointerType 121 if (!RetTy->isPointerTy()) 122 return FTy; 123 } 124 125 // re-create function type, using TypedPointerType instead of PointerType to 126 // properly trace argument types 127 const Type *RetTy = GR->getTypeForSPIRVType(SRetTy); 128 SmallVector<Type *, 4> ArgTys; 129 for (auto SArgTy : SArgTys) 130 ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy))); 131 return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false); 132 } 133 134 // This code restores function args/retvalue types for composite cases 135 // because the final types should still be aggregate whereas they're i32 136 // during the translation to cope with aggregate flattening etc. 137 static FunctionType *getOriginalFunctionType(const Function &F) { 138 auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs"); 139 if (NamedMD == nullptr) 140 return F.getFunctionType(); 141 142 Type *RetTy = F.getFunctionType()->getReturnType(); 143 SmallVector<Type *, 4> ArgTypes; 144 for (auto &Arg : F.args()) 145 ArgTypes.push_back(Arg.getType()); 146 147 auto ThisFuncMDIt = 148 std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) { 149 return isa<MDString>(N->getOperand(0)) && 150 cast<MDString>(N->getOperand(0))->getString() == F.getName(); 151 }); 152 // TODO: probably one function can have numerous type mutations, 153 // so we should support this. 154 if (ThisFuncMDIt != NamedMD->op_end()) { 155 auto *ThisFuncMD = *ThisFuncMDIt; 156 MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1)); 157 assert(MD && "MDNode operand is expected"); 158 ConstantInt *Const = getConstInt(MD, 0); 159 if (Const) { 160 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1)); 161 assert(CMeta && "ConstantAsMetadata operand is expected"); 162 assert(Const->getSExtValue() >= -1); 163 // Currently -1 indicates return value, greater values mean 164 // argument numbers. 165 if (Const->getSExtValue() == -1) 166 RetTy = CMeta->getType(); 167 else 168 ArgTypes[Const->getSExtValue()] = CMeta->getType(); 169 } 170 } 171 172 return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); 173 } 174 175 static SPIRV::AccessQualifier::AccessQualifier 176 getArgAccessQual(const Function &F, unsigned ArgIdx) { 177 if (F.getCallingConv() != CallingConv::SPIR_KERNEL) 178 return SPIRV::AccessQualifier::ReadWrite; 179 180 MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx); 181 if (!ArgAttribute) 182 return SPIRV::AccessQualifier::ReadWrite; 183 184 if (ArgAttribute->getString() == "read_only") 185 return SPIRV::AccessQualifier::ReadOnly; 186 if (ArgAttribute->getString() == "write_only") 187 return SPIRV::AccessQualifier::WriteOnly; 188 return SPIRV::AccessQualifier::ReadWrite; 189 } 190 191 static std::vector<SPIRV::Decoration::Decoration> 192 getKernelArgTypeQual(const Function &F, unsigned ArgIdx) { 193 MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx); 194 if (ArgAttribute && ArgAttribute->getString() == "volatile") 195 return {SPIRV::Decoration::Volatile}; 196 return {}; 197 } 198 199 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, 200 SPIRVGlobalRegistry *GR, 201 MachineIRBuilder &MIRBuilder, 202 const SPIRVSubtarget &ST) { 203 // Read argument's access qualifier from metadata or default. 204 SPIRV::AccessQualifier::AccessQualifier ArgAccessQual = 205 getArgAccessQual(F, ArgIdx); 206 207 Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx); 208 209 // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot 210 // be legally reassigned later). 211 if (!isPointerTy(OriginalArgType)) 212 return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual); 213 214 Argument *Arg = F.getArg(ArgIdx); 215 Type *ArgType = Arg->getType(); 216 if (isTypedPointerTy(ArgType)) { 217 SPIRVType *ElementType = GR->getOrCreateSPIRVType( 218 cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder); 219 return GR->getOrCreateSPIRVPointerType( 220 ElementType, MIRBuilder, 221 addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST)); 222 } 223 224 // In case OriginalArgType is of untyped pointer type, there are three 225 // possibilities: 226 // 1) This is a pointer of an LLVM IR element type, passed byval/byref. 227 // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type 228 // intrinsic assigning a TargetExtType. 229 // 3) This is a pointer, try to retrieve pointer element type from a 230 // spv_assign_ptr_type intrinsic or otherwise use default pointer element 231 // type. 232 if (hasPointeeTypeAttr(Arg)) { 233 SPIRVType *ElementType = 234 GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder); 235 return GR->getOrCreateSPIRVPointerType( 236 ElementType, MIRBuilder, 237 addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST)); 238 } 239 240 for (auto User : Arg->users()) { 241 auto *II = dyn_cast<IntrinsicInst>(User); 242 // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type. 243 if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) { 244 MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1)); 245 Type *BuiltinType = 246 cast<ConstantAsMetadata>(VMD->getMetadata())->getType(); 247 assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType"); 248 return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual); 249 } 250 251 // Check if this is spv_assign_ptr_type assigning pointer element type. 252 if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type) 253 continue; 254 255 MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1)); 256 Type *ElementTy = 257 toTypedPointer(cast<ConstantAsMetadata>(VMD->getMetadata())->getType()); 258 SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder); 259 return GR->getOrCreateSPIRVPointerType( 260 ElementType, MIRBuilder, 261 addressSpaceToStorageClass( 262 cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST)); 263 } 264 265 // Replace PointerType with TypedPointerType to be able to map SPIR-V types to 266 // LLVM types in a consistent manner 267 return GR->getOrCreateSPIRVType(toTypedPointer(OriginalArgType), MIRBuilder, 268 ArgAccessQual); 269 } 270 271 static SPIRV::ExecutionModel::ExecutionModel 272 getExecutionModel(const SPIRVSubtarget &STI, const Function &F) { 273 if (STI.isOpenCLEnv()) 274 return SPIRV::ExecutionModel::Kernel; 275 276 auto attribute = F.getFnAttribute("hlsl.shader"); 277 if (!attribute.isValid()) { 278 report_fatal_error( 279 "This entry point lacks mandatory hlsl.shader attribute."); 280 } 281 282 const auto value = attribute.getValueAsString(); 283 if (value == "compute") 284 return SPIRV::ExecutionModel::GLCompute; 285 286 report_fatal_error("This HLSL entry point is not supported by this backend."); 287 } 288 289 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, 290 const Function &F, 291 ArrayRef<ArrayRef<Register>> VRegs, 292 FunctionLoweringInfo &FLI) const { 293 // Discard the internal service function 294 if (F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid()) 295 return true; 296 297 assert(GR && "Must initialize the SPIRV type registry before lowering args."); 298 GR->setCurrentFunc(MIRBuilder.getMF()); 299 300 // Get access to information about available extensions 301 const SPIRVSubtarget *ST = 302 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget()); 303 304 // Assign types and names to all args, and store their types for later. 305 SmallVector<SPIRVType *, 4> ArgTypeVRegs; 306 if (VRegs.size() > 0) { 307 unsigned i = 0; 308 for (const auto &Arg : F.args()) { 309 // Currently formal args should use single registers. 310 // TODO: handle the case of multiple registers. 311 if (VRegs[i].size() > 1) 312 return false; 313 auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST); 314 GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF()); 315 ArgTypeVRegs.push_back(SpirvTy); 316 317 if (Arg.hasName()) 318 buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); 319 if (isPointerTyOrWrapper(Arg.getType())) { 320 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes()); 321 if (DerefBytes != 0) 322 buildOpDecorate(VRegs[i][0], MIRBuilder, 323 SPIRV::Decoration::MaxByteOffset, {DerefBytes}); 324 } 325 if (Arg.hasAttribute(Attribute::Alignment)) { 326 auto Alignment = static_cast<unsigned>( 327 Arg.getAttribute(Attribute::Alignment).getValueAsInt()); 328 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, 329 {Alignment}); 330 } 331 if (Arg.hasAttribute(Attribute::ReadOnly)) { 332 auto Attr = 333 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite); 334 buildOpDecorate(VRegs[i][0], MIRBuilder, 335 SPIRV::Decoration::FuncParamAttr, {Attr}); 336 } 337 if (Arg.hasAttribute(Attribute::ZExt)) { 338 auto Attr = 339 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext); 340 buildOpDecorate(VRegs[i][0], MIRBuilder, 341 SPIRV::Decoration::FuncParamAttr, {Attr}); 342 } 343 if (Arg.hasAttribute(Attribute::NoAlias)) { 344 auto Attr = 345 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias); 346 buildOpDecorate(VRegs[i][0], MIRBuilder, 347 SPIRV::Decoration::FuncParamAttr, {Attr}); 348 } 349 if (Arg.hasAttribute(Attribute::ByVal)) { 350 auto Attr = 351 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal); 352 buildOpDecorate(VRegs[i][0], MIRBuilder, 353 SPIRV::Decoration::FuncParamAttr, {Attr}); 354 } 355 if (Arg.hasAttribute(Attribute::StructRet)) { 356 auto Attr = 357 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Sret); 358 buildOpDecorate(VRegs[i][0], MIRBuilder, 359 SPIRV::Decoration::FuncParamAttr, {Attr}); 360 } 361 362 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { 363 std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs = 364 getKernelArgTypeQual(F, i); 365 for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs) 366 buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {}); 367 } 368 369 MDNode *Node = F.getMetadata("spirv.ParameterDecorations"); 370 if (Node && i < Node->getNumOperands() && 371 isa<MDNode>(Node->getOperand(i))) { 372 MDNode *MD = cast<MDNode>(Node->getOperand(i)); 373 for (const MDOperand &MDOp : MD->operands()) { 374 MDNode *MD2 = dyn_cast<MDNode>(MDOp); 375 assert(MD2 && "Metadata operand is expected"); 376 ConstantInt *Const = getConstInt(MD2, 0); 377 assert(Const && "MDOperand should be ConstantInt"); 378 auto Dec = 379 static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue()); 380 std::vector<uint32_t> DecVec; 381 for (unsigned j = 1; j < MD2->getNumOperands(); j++) { 382 ConstantInt *Const = getConstInt(MD2, j); 383 assert(Const && "MDOperand should be ConstantInt"); 384 DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue())); 385 } 386 buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec); 387 } 388 } 389 ++i; 390 } 391 } 392 393 auto MRI = MIRBuilder.getMRI(); 394 Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64)); 395 MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass); 396 if (F.isDeclaration()) 397 GR->add(&F, &MIRBuilder.getMF(), FuncVReg); 398 FunctionType *FTy = getOriginalFunctionType(F); 399 Type *FRetTy = FTy->getReturnType(); 400 if (isUntypedPointerTy(FRetTy)) { 401 if (Type *FRetElemTy = GR->findDeducedElementType(&F)) { 402 TypedPointerType *DerivedTy = TypedPointerType::get( 403 toTypedPointer(FRetElemTy), getPointerAddressSpace(FRetTy)); 404 GR->addReturnType(&F, DerivedTy); 405 FRetTy = DerivedTy; 406 } 407 } 408 SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder); 409 FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs); 410 SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( 411 FTy, RetTy, ArgTypeVRegs, MIRBuilder); 412 uint32_t FuncControl = getFunctionControl(F, ST); 413 414 // Add OpFunction instruction 415 MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction) 416 .addDef(FuncVReg) 417 .addUse(GR->getSPIRVTypeID(RetTy)) 418 .addImm(FuncControl) 419 .addUse(GR->getSPIRVTypeID(FuncTy)); 420 GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0)); 421 GR->addGlobalObject(&F, &MIRBuilder.getMF(), FuncVReg); 422 423 // Add OpFunctionParameter instructions 424 int i = 0; 425 for (const auto &Arg : F.args()) { 426 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); 427 Register ArgReg = VRegs[i][0]; 428 MRI->setRegClass(ArgReg, GR->getRegClass(ArgTypeVRegs[i])); 429 MRI->setType(ArgReg, GR->getRegType(ArgTypeVRegs[i])); 430 MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) 431 .addDef(ArgReg) 432 .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); 433 if (F.isDeclaration()) 434 GR->add(&Arg, &MIRBuilder.getMF(), ArgReg); 435 GR->addGlobalObject(&Arg, &MIRBuilder.getMF(), ArgReg); 436 i++; 437 } 438 // Name the function. 439 if (F.hasName()) 440 buildOpName(FuncVReg, F.getName(), MIRBuilder); 441 442 // Handle entry points and function linkage. 443 if (isEntryPoint(F)) { 444 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) 445 .addImm(static_cast<uint32_t>(getExecutionModel(*ST, F))) 446 .addUse(FuncVReg); 447 addStringImm(F.getName(), MIB); 448 } else if (F.getLinkage() != GlobalValue::InternalLinkage && 449 F.getLinkage() != GlobalValue::PrivateLinkage) { 450 SPIRV::LinkageType::LinkageType LnkTy = 451 F.isDeclaration() 452 ? SPIRV::LinkageType::Import 453 : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage && 454 ST->canUseExtension( 455 SPIRV::Extension::SPV_KHR_linkonce_odr) 456 ? SPIRV::LinkageType::LinkOnceODR 457 : SPIRV::LinkageType::Export); 458 buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 459 {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier()); 460 } 461 462 // Handle function pointers decoration 463 bool hasFunctionPointers = 464 ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers); 465 if (hasFunctionPointers) { 466 if (F.hasFnAttribute("referenced-indirectly")) { 467 assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) && 468 "Unexpected 'referenced-indirectly' attribute of the kernel " 469 "function"); 470 buildOpDecorate(FuncVReg, MIRBuilder, 471 SPIRV::Decoration::ReferencedIndirectlyINTEL, {}); 472 } 473 } 474 475 return true; 476 } 477 478 // Used to postpone producing of indirect function pointer types after all 479 // indirect calls info is collected 480 // TODO: 481 // - add a topological sort of IndirectCalls to ensure the best types knowledge 482 // - we may need to fix function formal parameter types if they are opaque 483 // pointers used as function pointers in these indirect calls 484 void SPIRVCallLowering::produceIndirectPtrTypes( 485 MachineIRBuilder &MIRBuilder) const { 486 // Create indirect call data types if any 487 MachineFunction &MF = MIRBuilder.getMF(); 488 for (auto const &IC : IndirectCalls) { 489 SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder); 490 SmallVector<SPIRVType *, 4> SpirvArgTypes; 491 for (size_t i = 0; i < IC.ArgTys.size(); ++i) { 492 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder); 493 SpirvArgTypes.push_back(SPIRVTy); 494 if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i])) 495 GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF); 496 } 497 // SPIR-V function type: 498 FunctionType *FTy = 499 FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false); 500 SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs( 501 FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder); 502 // SPIR-V pointer to function type: 503 SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType( 504 SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function); 505 // Correct the Callee type 506 GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF); 507 } 508 } 509 510 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, 511 CallLoweringInfo &Info) const { 512 // Currently call returns should have single vregs. 513 // TODO: handle the case of multiple registers. 514 if (Info.OrigRet.Regs.size() > 1) 515 return false; 516 MachineFunction &MF = MIRBuilder.getMF(); 517 GR->setCurrentFunc(MF); 518 const Function *CF = nullptr; 519 std::string DemangledName; 520 const Type *OrigRetTy = Info.OrigRet.Ty; 521 522 // Emit a regular OpFunctionCall. If it's an externally declared function, 523 // be sure to emit its type and function declaration here. It will be hoisted 524 // globally later. 525 if (Info.Callee.isGlobal()) { 526 std::string FuncName = Info.Callee.getGlobal()->getName().str(); 527 DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); 528 CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); 529 // TODO: support constexpr casts and indirect calls. 530 if (CF == nullptr) 531 return false; 532 if (FunctionType *FTy = getOriginalFunctionType(*CF)) { 533 OrigRetTy = FTy->getReturnType(); 534 if (isUntypedPointerTy(OrigRetTy)) { 535 if (auto *DerivedRetTy = GR->findReturnType(CF)) 536 OrigRetTy = DerivedRetTy; 537 } 538 } 539 } 540 541 MachineRegisterInfo *MRI = MIRBuilder.getMRI(); 542 Register ResVReg = 543 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; 544 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget()); 545 546 bool isFunctionDecl = CF && CF->isDeclaration(); 547 if (isFunctionDecl && !DemangledName.empty()) { 548 if (ResVReg.isValid()) { 549 if (!GR->getSPIRVTypeForVReg(ResVReg)) { 550 const Type *RetTy = OrigRetTy; 551 if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) { 552 const Value *OrigValue = Info.OrigRet.OrigValue; 553 if (!OrigValue) 554 OrigValue = Info.CB; 555 if (OrigValue) 556 if (Type *ElemTy = GR->findDeducedElementType(OrigValue)) 557 RetTy = 558 TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace()); 559 } 560 setRegClassType(ResVReg, RetTy, GR, MIRBuilder); 561 } 562 } else { 563 ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder); 564 } 565 SmallVector<Register, 8> ArgVRegs; 566 for (auto Arg : Info.OrigArgs) { 567 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); 568 Register ArgReg = Arg.Regs[0]; 569 ArgVRegs.push_back(ArgReg); 570 SPIRVType *SpvType = GR->getSPIRVTypeForVReg(ArgReg); 571 if (!SpvType) { 572 Type *ArgTy = nullptr; 573 if (auto *PtrArgTy = dyn_cast<PointerType>(Arg.Ty)) { 574 // If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) and we 575 // don't have access to original value in LLVM IR or info about 576 // deduced pointee type, then we should wait with setting the type for 577 // the virtual register until pre-legalizer step when we access 578 // @llvm.spv.assign.ptr.type.p...(...)'s info. 579 if (Arg.OrigValue) 580 if (Type *ElemTy = GR->findDeducedElementType(Arg.OrigValue)) 581 ArgTy = 582 TypedPointerType::get(ElemTy, PtrArgTy->getAddressSpace()); 583 } else { 584 ArgTy = Arg.Ty; 585 } 586 if (ArgTy) { 587 SpvType = GR->getOrCreateSPIRVType(ArgTy, MIRBuilder); 588 GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF); 589 } 590 } 591 if (!MRI->getRegClassOrNull(ArgReg)) { 592 // Either we have SpvType created, or Arg.Ty is an untyped pointer and 593 // we know its virtual register's class and type even if we don't know 594 // pointee type. 595 MRI->setRegClass(ArgReg, SpvType ? GR->getRegClass(SpvType) 596 : &SPIRV::pIDRegClass); 597 MRI->setType( 598 ArgReg, 599 SpvType ? GR->getRegType(SpvType) 600 : LLT::pointer(cast<PointerType>(Arg.Ty)->getAddressSpace(), 601 GR->getPointerSize())); 602 } 603 } 604 if (auto Res = 605 SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(), 606 MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR)) 607 return *Res; 608 } 609 610 if (isFunctionDecl && !GR->find(CF, &MF).isValid()) { 611 // Emit the type info and forward function declaration to the first MBB 612 // to ensure VReg definition dependencies are valid across all MBBs. 613 MachineIRBuilder FirstBlockBuilder; 614 FirstBlockBuilder.setMF(MF); 615 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0)); 616 617 SmallVector<ArrayRef<Register>, 8> VRegArgs; 618 SmallVector<SmallVector<Register, 1>, 8> ToInsert; 619 for (const Argument &Arg : CF->args()) { 620 if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) 621 continue; // Don't handle zero sized types. 622 Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(64)); 623 MRI->setRegClass(Reg, &SPIRV::iIDRegClass); 624 ToInsert.push_back({Reg}); 625 VRegArgs.push_back(ToInsert.back()); 626 } 627 // TODO: Reuse FunctionLoweringInfo 628 FunctionLoweringInfo FuncInfo; 629 lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo); 630 } 631 632 // Ignore the call if it's called from the internal service function 633 if (MIRBuilder.getMF() 634 .getFunction() 635 .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) 636 .isValid()) { 637 // insert a no-op 638 MIRBuilder.buildTrap(); 639 return true; 640 } 641 642 unsigned CallOp; 643 if (Info.CB->isIndirectCall()) { 644 if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) 645 report_fatal_error("An indirect call is encountered but SPIR-V without " 646 "extensions does not support it", 647 false); 648 // Set instruction operation according to SPV_INTEL_function_pointers 649 CallOp = SPIRV::OpFunctionPointerCallINTEL; 650 // Collect information about the indirect call to support possible 651 // specification of opaque ptr types of parent function's parameters 652 Register CalleeReg = Info.Callee.getReg(); 653 if (CalleeReg.isValid()) { 654 SPIRVCallLowering::SPIRVIndirectCall IndirectCall; 655 IndirectCall.Callee = CalleeReg; 656 IndirectCall.RetTy = OrigRetTy; 657 for (const auto &Arg : Info.OrigArgs) { 658 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); 659 IndirectCall.ArgTys.push_back(Arg.Ty); 660 IndirectCall.ArgRegs.push_back(Arg.Regs[0]); 661 } 662 IndirectCalls.push_back(IndirectCall); 663 } 664 } else { 665 // Emit a regular OpFunctionCall 666 CallOp = SPIRV::OpFunctionCall; 667 } 668 669 // Make sure there's a valid return reg, even for functions returning void. 670 if (!ResVReg.isValid()) 671 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); 672 SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder); 673 674 // Emit the call instruction and its args. 675 auto MIB = MIRBuilder.buildInstr(CallOp) 676 .addDef(ResVReg) 677 .addUse(GR->getSPIRVTypeID(RetType)) 678 .add(Info.Callee); 679 680 for (const auto &Arg : Info.OrigArgs) { 681 // Currently call args should have single vregs. 682 if (Arg.Regs.size() > 1) 683 return false; 684 MIB.addUse(Arg.Regs[0]); 685 } 686 return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(), 687 *ST->getRegBankInfo()); 688 } 689