1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the SPIRVGlobalRegistry class, 10 // which is used to maintain rich type information required for SPIR-V even 11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into 12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds 13 // and supports consistency of constants and global variables. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRV.h" 19 #include "SPIRVBuiltins.h" 20 #include "SPIRVSubtarget.h" 21 #include "SPIRVTargetMachine.h" 22 #include "SPIRVUtils.h" 23 #include "llvm/ADT/APInt.h" 24 #include "llvm/IR/Constants.h" 25 #include "llvm/IR/Type.h" 26 #include "llvm/Support/Casting.h" 27 #include <cassert> 28 #include <functional> 29 30 using namespace llvm; 31 32 inline unsigned typeToAddressSpace(const Type *Ty) { 33 if (auto PType = dyn_cast<TypedPointerType>(Ty)) 34 return PType->getAddressSpace(); 35 if (auto PType = dyn_cast<PointerType>(Ty)) 36 return PType->getAddressSpace(); 37 if (auto *ExtTy = dyn_cast<TargetExtType>(Ty); 38 ExtTy && isTypedPointerWrapper(ExtTy)) 39 return ExtTy->getIntParameter(0); 40 report_fatal_error("Unable to convert LLVM type to SPIRVType", true); 41 } 42 43 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) 44 : PointerSize(PointerSize), Bound(0) {} 45 46 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth, 47 Register VReg, 48 MachineInstr &I, 49 const SPIRVInstrInfo &TII) { 50 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 51 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 52 return SpirvType; 53 } 54 55 SPIRVType * 56 SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg, 57 MachineInstr &I, 58 const SPIRVInstrInfo &TII) { 59 SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII); 60 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 61 return SpirvType; 62 } 63 64 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg( 65 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I, 66 const SPIRVInstrInfo &TII) { 67 SPIRVType *SpirvType = 68 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII); 69 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 70 return SpirvType; 71 } 72 73 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( 74 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, 75 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 76 SPIRVType *SpirvType = 77 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 78 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); 79 return SpirvType; 80 } 81 82 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, 83 Register VReg, 84 const MachineFunction &MF) { 85 VRegToTypeMap[&MF][VReg] = SpirvType; 86 } 87 88 static Register createTypeVReg(MachineRegisterInfo &MRI) { 89 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(64)); 90 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 91 return Res; 92 } 93 94 inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) { 95 return createTypeVReg(MIRBuilder.getMF().getRegInfo()); 96 } 97 98 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { 99 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 100 return MIRBuilder.buildInstr(SPIRV::OpTypeBool) 101 .addDef(createTypeVReg(MIRBuilder)); 102 }); 103 } 104 105 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const { 106 if (Width > 64) 107 report_fatal_error("Unsupported integer width!"); 108 const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget()); 109 if (ST.canUseExtension( 110 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) 111 return Width; 112 if (Width <= 8) 113 Width = 8; 114 else if (Width <= 16) 115 Width = 16; 116 else if (Width <= 32) 117 Width = 32; 118 else 119 Width = 64; 120 return Width; 121 } 122 123 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width, 124 MachineIRBuilder &MIRBuilder, 125 bool IsSigned) { 126 Width = adjustOpTypeIntWidth(Width); 127 const SPIRVSubtarget &ST = 128 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); 129 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 130 if (ST.canUseExtension( 131 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) { 132 MIRBuilder.buildInstr(SPIRV::OpExtension) 133 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers); 134 MIRBuilder.buildInstr(SPIRV::OpCapability) 135 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL); 136 } 137 return MIRBuilder.buildInstr(SPIRV::OpTypeInt) 138 .addDef(createTypeVReg(MIRBuilder)) 139 .addImm(Width) 140 .addImm(IsSigned ? 1 : 0); 141 }); 142 } 143 144 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, 145 MachineIRBuilder &MIRBuilder) { 146 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 147 return MIRBuilder.buildInstr(SPIRV::OpTypeFloat) 148 .addDef(createTypeVReg(MIRBuilder)) 149 .addImm(Width); 150 }); 151 } 152 153 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { 154 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 155 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) 156 .addDef(createTypeVReg(MIRBuilder)); 157 }); 158 } 159 160 void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) { 161 // TODO: 162 // - take into account duplicate tracker case which is a known issue, 163 // - review other data structure wrt. possible issues related to removal 164 // of a machine instruction during instruction selection. 165 const MachineFunction *MF = MI->getParent()->getParent(); 166 auto It = LastInsertedTypeMap.find(MF); 167 if (It == LastInsertedTypeMap.end()) 168 return; 169 if (It->second == MI) 170 LastInsertedTypeMap.erase(MF); 171 } 172 173 SPIRVType *SPIRVGlobalRegistry::createOpType( 174 MachineIRBuilder &MIRBuilder, 175 std::function<MachineInstr *(MachineIRBuilder &)> Op) { 176 auto oldInsertPoint = MIRBuilder.getInsertPt(); 177 MachineBasicBlock *OldMBB = &MIRBuilder.getMBB(); 178 MachineBasicBlock *NewMBB = &*MIRBuilder.getMF().begin(); 179 180 auto LastInsertedType = LastInsertedTypeMap.find(CurMF); 181 if (LastInsertedType != LastInsertedTypeMap.end()) { 182 auto It = LastInsertedType->second->getIterator(); 183 // It might happen that this instruction was removed from the first MBB, 184 // hence the Parent's check. 185 MachineBasicBlock::iterator InsertAt; 186 if (It->getParent() != NewMBB) 187 InsertAt = oldInsertPoint->getParent() == NewMBB 188 ? oldInsertPoint 189 : getInsertPtValidEnd(NewMBB); 190 else if (It->getNextNode()) 191 InsertAt = It->getNextNode()->getIterator(); 192 else 193 InsertAt = getInsertPtValidEnd(NewMBB); 194 MIRBuilder.setInsertPt(*NewMBB, InsertAt); 195 } else { 196 MIRBuilder.setInsertPt(*NewMBB, NewMBB->begin()); 197 auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr); 198 assert(Result.second); 199 LastInsertedType = Result.first; 200 } 201 202 MachineInstr *Type = Op(MIRBuilder); 203 // We expect all users of this function to insert definitions at the insertion 204 // point set above that is always the first MBB. 205 assert(Type->getParent() == NewMBB); 206 LastInsertedType->second = Type; 207 208 MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint); 209 return Type; 210 } 211 212 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, 213 SPIRVType *ElemType, 214 MachineIRBuilder &MIRBuilder) { 215 auto EleOpc = ElemType->getOpcode(); 216 (void)EleOpc; 217 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || 218 EleOpc == SPIRV::OpTypeBool) && 219 "Invalid vector element type"); 220 221 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 222 return MIRBuilder.buildInstr(SPIRV::OpTypeVector) 223 .addDef(createTypeVReg(MIRBuilder)) 224 .addUse(getSPIRVTypeID(ElemType)) 225 .addImm(NumElems); 226 }); 227 } 228 229 std::tuple<Register, ConstantInt *, bool, unsigned> 230 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType, 231 MachineIRBuilder *MIRBuilder, 232 MachineInstr *I, 233 const SPIRVInstrInfo *TII) { 234 assert(SpvType); 235 const IntegerType *LLVMIntTy = 236 cast<IntegerType>(getTypeForSPIRVType(SpvType)); 237 unsigned BitWidth = getScalarOrVectorBitWidth(SpvType); 238 bool NewInstr = false; 239 // Find a constant in DT or build a new one. 240 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 241 Register Res = DT.find(CI, CurMF); 242 if (!Res.isValid()) { 243 Res = 244 CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); 245 CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass); 246 if (MIRBuilder) 247 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder); 248 else 249 assignIntTypeToVReg(BitWidth, Res, *I, *TII); 250 DT.add(CI, CurMF, Res); 251 NewInstr = true; 252 } 253 return std::make_tuple(Res, CI, NewInstr, BitWidth); 254 } 255 256 std::tuple<Register, ConstantFP *, bool, unsigned> 257 SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType, 258 MachineIRBuilder *MIRBuilder, 259 MachineInstr *I, 260 const SPIRVInstrInfo *TII) { 261 assert(SpvType); 262 LLVMContext &Ctx = CurMF->getFunction().getContext(); 263 const Type *LLVMFloatTy = getTypeForSPIRVType(SpvType); 264 unsigned BitWidth = getScalarOrVectorBitWidth(SpvType); 265 bool NewInstr = false; 266 // Find a constant in DT or build a new one. 267 auto *const CI = ConstantFP::get(Ctx, Val); 268 Register Res = DT.find(CI, CurMF); 269 if (!Res.isValid()) { 270 Res = 271 CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); 272 CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass); 273 if (MIRBuilder) 274 assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder); 275 else 276 assignFloatTypeToVReg(BitWidth, Res, *I, *TII); 277 DT.add(CI, CurMF, Res); 278 NewInstr = true; 279 } 280 return std::make_tuple(Res, CI, NewInstr, BitWidth); 281 } 282 283 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I, 284 SPIRVType *SpvType, 285 const SPIRVInstrInfo &TII, 286 bool ZeroAsNull) { 287 assert(SpvType); 288 ConstantFP *CI; 289 Register Res; 290 bool New; 291 unsigned BitWidth; 292 std::tie(Res, CI, New, BitWidth) = 293 getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII); 294 // If we have found Res register which is defined by the passed G_CONSTANT 295 // machine instruction, a new constant instruction should be created. 296 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg())) 297 return Res; 298 MachineIRBuilder MIRBuilder(I); 299 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 300 MachineInstrBuilder MIB; 301 // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0) 302 if (Val.isPosZero() && ZeroAsNull) { 303 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 304 .addDef(Res) 305 .addUse(getSPIRVTypeID(SpvType)); 306 } else { 307 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF) 308 .addDef(Res) 309 .addUse(getSPIRVTypeID(SpvType)); 310 addNumImm( 311 APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()), 312 MIB); 313 } 314 const auto &ST = CurMF->getSubtarget(); 315 constrainSelectedInstRegOperands( 316 *MIB, *ST.getInstrInfo(), *ST.getRegisterInfo(), *ST.getRegBankInfo()); 317 return MIB; 318 }); 319 return Res; 320 } 321 322 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, 323 SPIRVType *SpvType, 324 const SPIRVInstrInfo &TII, 325 bool ZeroAsNull) { 326 assert(SpvType); 327 ConstantInt *CI; 328 Register Res; 329 bool New; 330 unsigned BitWidth; 331 std::tie(Res, CI, New, BitWidth) = 332 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII); 333 // If we have found Res register which is defined by the passed G_CONSTANT 334 // machine instruction, a new constant instruction should be created. 335 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg())) 336 return Res; 337 338 MachineIRBuilder MIRBuilder(I); 339 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 340 MachineInstrBuilder MIB; 341 if (Val || !ZeroAsNull) { 342 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) 343 .addDef(Res) 344 .addUse(getSPIRVTypeID(SpvType)); 345 addNumImm(APInt(BitWidth, Val), MIB); 346 } else { 347 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 348 .addDef(Res) 349 .addUse(getSPIRVTypeID(SpvType)); 350 } 351 const auto &ST = CurMF->getSubtarget(); 352 constrainSelectedInstRegOperands( 353 *MIB, *ST.getInstrInfo(), *ST.getRegisterInfo(), *ST.getRegBankInfo()); 354 return MIB; 355 }); 356 return Res; 357 } 358 359 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, 360 MachineIRBuilder &MIRBuilder, 361 SPIRVType *SpvType, bool EmitIR, 362 bool ZeroAsNull) { 363 assert(SpvType); 364 auto &MF = MIRBuilder.getMF(); 365 const IntegerType *LLVMIntTy = 366 cast<IntegerType>(getTypeForSPIRVType(SpvType)); 367 // Find a constant in DT or build a new one. 368 const auto ConstInt = 369 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 370 Register Res = DT.find(ConstInt, &MF); 371 if (!Res.isValid()) { 372 unsigned BitWidth = getScalarOrVectorBitWidth(SpvType); 373 LLT LLTy = LLT::scalar(BitWidth); 374 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy); 375 MF.getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass); 376 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder, 377 SPIRV::AccessQualifier::ReadWrite, EmitIR); 378 DT.add(ConstInt, &MIRBuilder.getMF(), Res); 379 if (EmitIR) { 380 MIRBuilder.buildConstant(Res, *ConstInt); 381 } else { 382 Register SpvTypeReg = getSPIRVTypeID(SpvType); 383 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 384 MachineInstrBuilder MIB; 385 if (Val || !ZeroAsNull) { 386 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) 387 .addDef(Res) 388 .addUse(SpvTypeReg); 389 addNumImm(APInt(BitWidth, Val), MIB); 390 } else { 391 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 392 .addDef(Res) 393 .addUse(SpvTypeReg); 394 } 395 const auto &Subtarget = CurMF->getSubtarget(); 396 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 397 *Subtarget.getRegisterInfo(), 398 *Subtarget.getRegBankInfo()); 399 return MIB; 400 }); 401 } 402 } 403 return Res; 404 } 405 406 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, 407 MachineIRBuilder &MIRBuilder, 408 SPIRVType *SpvType) { 409 auto &MF = MIRBuilder.getMF(); 410 auto &Ctx = MF.getFunction().getContext(); 411 if (!SpvType) { 412 const Type *LLVMFPTy = Type::getFloatTy(Ctx); 413 SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder); 414 } 415 // Find a constant in DT or build a new one. 416 const auto ConstFP = ConstantFP::get(Ctx, Val); 417 Register Res = DT.find(ConstFP, &MF); 418 if (!Res.isValid()) { 419 Res = MF.getRegInfo().createGenericVirtualRegister( 420 LLT::scalar(getScalarOrVectorBitWidth(SpvType))); 421 MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass); 422 assignSPIRVTypeToVReg(SpvType, Res, MF); 423 DT.add(ConstFP, &MF, Res); 424 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 425 MachineInstrBuilder MIB; 426 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF) 427 .addDef(Res) 428 .addUse(getSPIRVTypeID(SpvType)); 429 addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB); 430 return MIB; 431 }); 432 } 433 434 return Res; 435 } 436 437 Register SPIRVGlobalRegistry::getOrCreateBaseRegister( 438 Constant *Val, MachineInstr &I, SPIRVType *SpvType, 439 const SPIRVInstrInfo &TII, unsigned BitWidth, bool ZeroAsNull) { 440 SPIRVType *Type = SpvType; 441 if (SpvType->getOpcode() == SPIRV::OpTypeVector || 442 SpvType->getOpcode() == SPIRV::OpTypeArray) { 443 auto EleTypeReg = SpvType->getOperand(1).getReg(); 444 Type = getSPIRVTypeForVReg(EleTypeReg); 445 } 446 if (Type->getOpcode() == SPIRV::OpTypeFloat) { 447 SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII); 448 return getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(), I, 449 SpvBaseType, TII, ZeroAsNull); 450 } 451 assert(Type->getOpcode() == SPIRV::OpTypeInt); 452 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 453 return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I, 454 SpvBaseType, TII, ZeroAsNull); 455 } 456 457 Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull( 458 Constant *Val, MachineInstr &I, SPIRVType *SpvType, 459 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth, 460 unsigned ElemCnt, bool ZeroAsNull) { 461 // Find a constant vector or array in DT or build a new one. 462 Register Res = DT.find(CA, CurMF); 463 // If no values are attached, the composite is null constant. 464 bool IsNull = Val->isNullValue() && ZeroAsNull; 465 if (!Res.isValid()) { 466 // SpvScalConst should be created before SpvVecConst to avoid undefined ID 467 // error on validation. 468 // TODO: can moved below once sorting of types/consts/defs is implemented. 469 Register SpvScalConst; 470 if (!IsNull) 471 SpvScalConst = 472 getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth, ZeroAsNull); 473 474 LLT LLTy = LLT::scalar(64); 475 Register SpvVecConst = 476 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 477 CurMF->getRegInfo().setRegClass(SpvVecConst, getRegClass(SpvType)); 478 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 479 DT.add(CA, CurMF, SpvVecConst); 480 MachineIRBuilder MIRBuilder(I); 481 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 482 MachineInstrBuilder MIB; 483 if (!IsNull) { 484 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite) 485 .addDef(SpvVecConst) 486 .addUse(getSPIRVTypeID(SpvType)); 487 for (unsigned i = 0; i < ElemCnt; ++i) 488 MIB.addUse(SpvScalConst); 489 } else { 490 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 491 .addDef(SpvVecConst) 492 .addUse(getSPIRVTypeID(SpvType)); 493 } 494 const auto &Subtarget = CurMF->getSubtarget(); 495 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 496 *Subtarget.getRegisterInfo(), 497 *Subtarget.getRegBankInfo()); 498 return MIB; 499 }); 500 return SpvVecConst; 501 } 502 return Res; 503 } 504 505 Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val, 506 MachineInstr &I, 507 SPIRVType *SpvType, 508 const SPIRVInstrInfo &TII, 509 bool ZeroAsNull) { 510 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 511 assert(LLVMTy->isVectorTy()); 512 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 513 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 514 assert(LLVMBaseTy->isIntegerTy()); 515 auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val); 516 auto *ConstVec = 517 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal); 518 unsigned BW = getScalarOrVectorBitWidth(SpvType); 519 return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW, 520 SpvType->getOperand(2).getImm(), 521 ZeroAsNull); 522 } 523 524 Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val, 525 MachineInstr &I, 526 SPIRVType *SpvType, 527 const SPIRVInstrInfo &TII, 528 bool ZeroAsNull) { 529 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 530 assert(LLVMTy->isVectorTy()); 531 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 532 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 533 assert(LLVMBaseTy->isFloatingPointTy()); 534 auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val); 535 auto *ConstVec = 536 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal); 537 unsigned BW = getScalarOrVectorBitWidth(SpvType); 538 return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW, 539 SpvType->getOperand(2).getImm(), 540 ZeroAsNull); 541 } 542 543 Register SPIRVGlobalRegistry::getOrCreateConstIntArray( 544 uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType, 545 const SPIRVInstrInfo &TII) { 546 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 547 assert(LLVMTy->isArrayTy()); 548 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); 549 Type *LLVMBaseTy = LLVMArrTy->getElementType(); 550 Constant *CI = ConstantInt::get(LLVMBaseTy, Val); 551 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 552 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); 553 // The following is reasonably unique key that is better that [Val]. The naive 554 // alternative would be something along the lines of: 555 // SmallVector<Constant *> NumCI(Num, CI); 556 // Constant *UniqueKey = 557 // ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI); 558 // that would be a truly unique but dangerous key, because it could lead to 559 // the creation of constants of arbitrary length (that is, the parameter of 560 // memset) which were missing in the original module. 561 Constant *UniqueKey = ConstantStruct::getAnon( 562 {PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)), 563 ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)}); 564 return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW, 565 LLVMArrTy->getNumElements()); 566 } 567 568 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull( 569 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, 570 Constant *CA, unsigned BitWidth, unsigned ElemCnt) { 571 Register Res = DT.find(CA, CurMF); 572 if (!Res.isValid()) { 573 Register SpvScalConst; 574 if (Val || EmitIR) { 575 SPIRVType *SpvBaseType = 576 getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder); 577 SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR); 578 } 579 LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(64); 580 Register SpvVecConst = 581 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 582 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass); 583 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 584 DT.add(CA, CurMF, SpvVecConst); 585 if (EmitIR) { 586 MIRBuilder.buildSplatBuildVector(SpvVecConst, SpvScalConst); 587 } else { 588 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 589 if (Val) { 590 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite) 591 .addDef(SpvVecConst) 592 .addUse(getSPIRVTypeID(SpvType)); 593 for (unsigned i = 0; i < ElemCnt; ++i) 594 MIB.addUse(SpvScalConst); 595 return MIB; 596 } else { 597 return MIRBuilder.buildInstr(SPIRV::OpConstantNull) 598 .addDef(SpvVecConst) 599 .addUse(getSPIRVTypeID(SpvType)); 600 } 601 }); 602 } 603 return SpvVecConst; 604 } 605 return Res; 606 } 607 608 Register 609 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, 610 MachineIRBuilder &MIRBuilder, 611 SPIRVType *SpvType, bool EmitIR) { 612 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 613 assert(LLVMTy->isVectorTy()); 614 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 615 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 616 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 617 auto ConstVec = 618 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 619 unsigned BW = getScalarOrVectorBitWidth(SpvType); 620 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, 621 ConstVec, BW, 622 SpvType->getOperand(2).getImm()); 623 } 624 625 Register 626 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder, 627 SPIRVType *SpvType) { 628 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 629 unsigned AddressSpace = typeToAddressSpace(LLVMTy); 630 // Find a constant in DT or build a new one. 631 Constant *CP = ConstantPointerNull::get( 632 PointerType::get(LLVMTy->getContext(), AddressSpace)); 633 Register Res = DT.find(CP, CurMF); 634 if (!Res.isValid()) { 635 LLT LLTy = LLT::pointer(AddressSpace, PointerSize); 636 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 637 CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass); 638 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 639 createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 640 return MIRBuilder.buildInstr(SPIRV::OpConstantNull) 641 .addDef(Res) 642 .addUse(getSPIRVTypeID(SpvType)); 643 }); 644 DT.add(CP, CurMF, Res); 645 } 646 return Res; 647 } 648 649 Register SPIRVGlobalRegistry::buildConstantSampler( 650 Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode, 651 MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) { 652 SPIRVType *SampTy; 653 if (SpvType) 654 SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder); 655 else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", 656 MIRBuilder)) == nullptr) 657 report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t"); 658 659 auto Sampler = 660 ResReg.isValid() 661 ? ResReg 662 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); 663 auto Res = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 664 return MIRBuilder.buildInstr(SPIRV::OpConstantSampler) 665 .addDef(Sampler) 666 .addUse(getSPIRVTypeID(SampTy)) 667 .addImm(AddrMode) 668 .addImm(Param) 669 .addImm(FilerMode); 670 }); 671 assert(Res->getOperand(0).isReg()); 672 return Res->getOperand(0).getReg(); 673 } 674 675 Register SPIRVGlobalRegistry::buildGlobalVariable( 676 Register ResVReg, SPIRVType *BaseType, StringRef Name, 677 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage, 678 const MachineInstr *Init, bool IsConst, bool HasLinkageTy, 679 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, 680 bool IsInstSelector) { 681 const GlobalVariable *GVar = nullptr; 682 if (GV) 683 GVar = cast<const GlobalVariable>(GV); 684 else { 685 // If GV is not passed explicitly, use the name to find or construct 686 // the global variable. 687 Module *M = MIRBuilder.getMF().getFunction().getParent(); 688 GVar = M->getGlobalVariable(Name); 689 if (GVar == nullptr) { 690 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 691 // Module takes ownership of the global var. 692 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, 693 GlobalValue::ExternalLinkage, nullptr, 694 Twine(Name)); 695 } 696 GV = GVar; 697 } 698 Register Reg = DT.find(GVar, &MIRBuilder.getMF()); 699 if (Reg.isValid()) { 700 if (Reg != ResVReg) 701 MIRBuilder.buildCopy(ResVReg, Reg); 702 return ResVReg; 703 } 704 705 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) 706 .addDef(ResVReg) 707 .addUse(getSPIRVTypeID(BaseType)) 708 .addImm(static_cast<uint32_t>(Storage)); 709 710 if (Init != 0) { 711 MIB.addUse(Init->getOperand(0).getReg()); 712 } 713 714 // ISel may introduce a new register on this step, so we need to add it to 715 // DT and correct its type avoiding fails on the next stage. 716 if (IsInstSelector) { 717 const auto &Subtarget = CurMF->getSubtarget(); 718 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 719 *Subtarget.getRegisterInfo(), 720 *Subtarget.getRegBankInfo()); 721 } 722 Reg = MIB->getOperand(0).getReg(); 723 DT.add(GVar, &MIRBuilder.getMF(), Reg); 724 addGlobalObject(GVar, &MIRBuilder.getMF(), Reg); 725 726 // Set to Reg the same type as ResVReg has. 727 auto MRI = MIRBuilder.getMRI(); 728 if (Reg != ResVReg) { 729 LLT RegLLTy = 730 LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize()); 731 MRI->setType(Reg, RegLLTy); 732 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 733 } else { 734 // Our knowledge about the type may be updated. 735 // If that's the case, we need to update a type 736 // associated with the register. 737 SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg); 738 if (!DefType || DefType != BaseType) 739 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 740 } 741 742 // If it's a global variable with name, output OpName for it. 743 if (GVar && GVar->hasName()) 744 buildOpName(Reg, GVar->getName(), MIRBuilder); 745 746 // Output decorations for the GV. 747 // TODO: maybe move to GenerateDecorations pass. 748 const SPIRVSubtarget &ST = 749 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); 750 if (IsConst && ST.isOpenCLEnv()) 751 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); 752 753 if (GVar && GVar->getAlign().valueOrOne().value() != 1) { 754 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value(); 755 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment}); 756 } 757 758 if (HasLinkageTy) 759 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 760 {static_cast<uint32_t>(LinkageType)}, Name); 761 762 SPIRV::BuiltIn::BuiltIn BuiltInId; 763 if (getSpirvBuiltInIdByName(Name, BuiltInId)) 764 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn, 765 {static_cast<uint32_t>(BuiltInId)}); 766 767 // If it's a global variable with "spirv.Decorations" metadata node 768 // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations" 769 // arguments. 770 MDNode *GVarMD = nullptr; 771 if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr) 772 buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD); 773 774 return Reg; 775 } 776 777 static std::string GetSpirvImageTypeName(const SPIRVType *Type, 778 MachineIRBuilder &MIRBuilder, 779 const std::string &Prefix); 780 781 static std::string buildSpirvTypeName(const SPIRVType *Type, 782 MachineIRBuilder &MIRBuilder) { 783 switch (Type->getOpcode()) { 784 case SPIRV::OpTypeSampledImage: { 785 return GetSpirvImageTypeName(Type, MIRBuilder, "sampled_image_"); 786 } 787 case SPIRV::OpTypeImage: { 788 return GetSpirvImageTypeName(Type, MIRBuilder, "image_"); 789 } 790 case SPIRV::OpTypeArray: { 791 MachineRegisterInfo *MRI = MIRBuilder.getMRI(); 792 Register ElementTypeReg = Type->getOperand(1).getReg(); 793 auto *ElementType = MRI->getUniqueVRegDef(ElementTypeReg); 794 const SPIRVType *TypeInst = MRI->getVRegDef(Type->getOperand(2).getReg()); 795 assert(TypeInst->getOpcode() != SPIRV::OpConstantI); 796 MachineInstr *ImmInst = MRI->getVRegDef(TypeInst->getOperand(1).getReg()); 797 assert(ImmInst->getOpcode() == TargetOpcode::G_CONSTANT); 798 uint32_t ArraySize = ImmInst->getOperand(1).getCImm()->getZExtValue(); 799 return (buildSpirvTypeName(ElementType, MIRBuilder) + Twine("[") + 800 Twine(ArraySize) + Twine("]")) 801 .str(); 802 } 803 case SPIRV::OpTypeFloat: 804 return ("f" + Twine(Type->getOperand(1).getImm())).str(); 805 case SPIRV::OpTypeSampler: 806 return ("sampler"); 807 case SPIRV::OpTypeInt: 808 if (Type->getOperand(2).getImm()) 809 return ("i" + Twine(Type->getOperand(1).getImm())).str(); 810 return ("u" + Twine(Type->getOperand(1).getImm())).str(); 811 default: 812 llvm_unreachable("Trying to the the name of an unknown type."); 813 } 814 } 815 816 static std::string GetSpirvImageTypeName(const SPIRVType *Type, 817 MachineIRBuilder &MIRBuilder, 818 const std::string &Prefix) { 819 Register SampledTypeReg = Type->getOperand(1).getReg(); 820 auto *SampledType = MIRBuilder.getMRI()->getUniqueVRegDef(SampledTypeReg); 821 std::string TypeName = Prefix + buildSpirvTypeName(SampledType, MIRBuilder); 822 for (uint32_t I = 2; I < Type->getNumOperands(); ++I) { 823 TypeName = (TypeName + '_' + Twine(Type->getOperand(I).getImm())).str(); 824 } 825 return TypeName; 826 } 827 828 Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding( 829 const SPIRVType *VarType, uint32_t Set, uint32_t Binding, 830 MachineIRBuilder &MIRBuilder) { 831 SPIRVType *VarPointerTypeReg = getOrCreateSPIRVPointerType( 832 VarType, MIRBuilder, SPIRV::StorageClass::UniformConstant); 833 Register VarReg = 834 MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); 835 836 // TODO: The name should come from the llvm-ir, but how that name will be 837 // passed from the HLSL to the backend has not been decided. Using this place 838 // holder for now. 839 std::string Name = ("__resource_" + buildSpirvTypeName(VarType, MIRBuilder) + 840 "_" + Twine(Set) + "_" + Twine(Binding)) 841 .str(); 842 buildGlobalVariable(VarReg, VarPointerTypeReg, Name, nullptr, 843 SPIRV::StorageClass::UniformConstant, nullptr, false, 844 false, SPIRV::LinkageType::Import, MIRBuilder, false); 845 846 buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set}); 847 buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding}); 848 return VarReg; 849 } 850 851 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, 852 SPIRVType *ElemType, 853 MachineIRBuilder &MIRBuilder, 854 bool EmitIR) { 855 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && 856 "Invalid array element type"); 857 SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder); 858 Register NumElementsVReg = 859 buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR); 860 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 861 return MIRBuilder.buildInstr(SPIRV::OpTypeArray) 862 .addDef(createTypeVReg(MIRBuilder)) 863 .addUse(getSPIRVTypeID(ElemType)) 864 .addUse(NumElementsVReg); 865 }); 866 } 867 868 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty, 869 MachineIRBuilder &MIRBuilder) { 870 assert(Ty->hasName()); 871 const StringRef Name = Ty->hasName() ? Ty->getName() : ""; 872 Register ResVReg = createTypeVReg(MIRBuilder); 873 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 874 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg); 875 addStringImm(Name, MIB); 876 buildOpName(ResVReg, Name, MIRBuilder); 877 return MIB; 878 }); 879 } 880 881 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty, 882 MachineIRBuilder &MIRBuilder, 883 bool EmitIR) { 884 SmallVector<Register, 4> FieldTypes; 885 for (const auto &Elem : Ty->elements()) { 886 SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder); 887 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid && 888 "Invalid struct element type"); 889 FieldTypes.push_back(getSPIRVTypeID(ElemTy)); 890 } 891 Register ResVReg = createTypeVReg(MIRBuilder); 892 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 893 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg); 894 for (const auto &Ty : FieldTypes) 895 MIB.addUse(Ty); 896 if (Ty->hasName()) 897 buildOpName(ResVReg, Ty->getName(), MIRBuilder); 898 if (Ty->isPacked()) 899 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {}); 900 return MIB; 901 }); 902 } 903 904 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType( 905 const Type *Ty, MachineIRBuilder &MIRBuilder, 906 SPIRV::AccessQualifier::AccessQualifier AccQual) { 907 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type"); 908 return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this); 909 } 910 911 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer( 912 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType, 913 MachineIRBuilder &MIRBuilder, Register Reg) { 914 if (!Reg.isValid()) 915 Reg = createTypeVReg(MIRBuilder); 916 917 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 918 return MIRBuilder.buildInstr(SPIRV::OpTypePointer) 919 .addDef(Reg) 920 .addImm(static_cast<uint32_t>(SC)) 921 .addUse(getSPIRVTypeID(ElemType)); 922 }); 923 } 924 925 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer( 926 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) { 927 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 928 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer) 929 .addUse(createTypeVReg(MIRBuilder)) 930 .addImm(static_cast<uint32_t>(SC)); 931 }); 932 } 933 934 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( 935 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, 936 MachineIRBuilder &MIRBuilder) { 937 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) 938 .addDef(createTypeVReg(MIRBuilder)) 939 .addUse(getSPIRVTypeID(RetType)); 940 for (const SPIRVType *ArgType : ArgTypes) 941 MIB.addUse(getSPIRVTypeID(ArgType)); 942 return MIB; 943 } 944 945 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs( 946 const Type *Ty, SPIRVType *RetType, 947 const SmallVectorImpl<SPIRVType *> &ArgTypes, 948 MachineIRBuilder &MIRBuilder) { 949 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 950 if (Reg.isValid()) 951 return getSPIRVTypeForVReg(Reg); 952 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder); 953 DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType)); 954 return finishCreatingSPIRVType(Ty, SpirvType); 955 } 956 957 SPIRVType *SPIRVGlobalRegistry::findSPIRVType( 958 const Type *Ty, MachineIRBuilder &MIRBuilder, 959 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 960 Ty = adjustIntTypeByWidth(Ty); 961 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 962 if (Reg.isValid()) 963 return getSPIRVTypeForVReg(Reg); 964 if (ForwardPointerTypes.contains(Ty)) 965 return ForwardPointerTypes[Ty]; 966 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR); 967 } 968 969 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const { 970 assert(SpirvType && "Attempting to get type id for nullptr type."); 971 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer) 972 return SpirvType->uses().begin()->getReg(); 973 return SpirvType->defs().begin()->getReg(); 974 } 975 976 // We need to use a new LLVM integer type if there is a mismatch between 977 // number of bits in LLVM and SPIRV integer types to let DuplicateTracker 978 // ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without 979 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the 980 // same "OpTypeInt 8" type for a series of LLVM integer types with number of 981 // bits less than 8. This would lead to duplicate type definitions 982 // eventually due to the method that DuplicateTracker utilizes to reason 983 // about uniqueness of type records. 984 const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const { 985 if (auto IType = dyn_cast<IntegerType>(Ty)) { 986 unsigned SrcBitWidth = IType->getBitWidth(); 987 if (SrcBitWidth > 1) { 988 unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth); 989 // Maybe change source LLVM type to keep DuplicateTracker consistent. 990 if (SrcBitWidth != BitWidth) 991 Ty = IntegerType::get(Ty->getContext(), BitWidth); 992 } 993 } 994 return Ty; 995 } 996 997 SPIRVType *SPIRVGlobalRegistry::createSPIRVType( 998 const Type *Ty, MachineIRBuilder &MIRBuilder, 999 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 1000 if (isSpecialOpaqueType(Ty)) 1001 return getOrCreateSpecialType(Ty, MIRBuilder, AccQual); 1002 auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses(); 1003 auto t = TypeToSPIRVTypeMap.find(Ty); 1004 if (t != TypeToSPIRVTypeMap.end()) { 1005 auto tt = t->second.find(&MIRBuilder.getMF()); 1006 if (tt != t->second.end()) 1007 return getSPIRVTypeForVReg(tt->second); 1008 } 1009 1010 if (auto IType = dyn_cast<IntegerType>(Ty)) { 1011 const unsigned Width = IType->getBitWidth(); 1012 return Width == 1 ? getOpTypeBool(MIRBuilder) 1013 : getOpTypeInt(Width, MIRBuilder, false); 1014 } 1015 if (Ty->isFloatingPointTy()) 1016 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); 1017 if (Ty->isVoidTy()) 1018 return getOpTypeVoid(MIRBuilder); 1019 if (Ty->isVectorTy()) { 1020 SPIRVType *El = 1021 findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder); 1022 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El, 1023 MIRBuilder); 1024 } 1025 if (Ty->isArrayTy()) { 1026 SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder); 1027 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); 1028 } 1029 if (auto SType = dyn_cast<StructType>(Ty)) { 1030 if (SType->isOpaque()) 1031 return getOpTypeOpaque(SType, MIRBuilder); 1032 return getOpTypeStruct(SType, MIRBuilder, EmitIR); 1033 } 1034 if (auto FType = dyn_cast<FunctionType>(Ty)) { 1035 SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder); 1036 SmallVector<SPIRVType *, 4> ParamTypes; 1037 for (const auto &t : FType->params()) { 1038 ParamTypes.push_back(findSPIRVType(t, MIRBuilder)); 1039 } 1040 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); 1041 } 1042 1043 unsigned AddrSpace = typeToAddressSpace(Ty); 1044 SPIRVType *SpvElementType = nullptr; 1045 if (Type *ElemTy = ::getPointeeType(Ty)) 1046 SpvElementType = getOrCreateSPIRVType(ElemTy, MIRBuilder, AccQual, EmitIR); 1047 else 1048 SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); 1049 1050 // Get access to information about available extensions 1051 const SPIRVSubtarget *ST = 1052 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget()); 1053 auto SC = addressSpaceToStorageClass(AddrSpace, *ST); 1054 // Null pointer means we have a loop in type definitions, make and 1055 // return corresponding OpTypeForwardPointer. 1056 if (SpvElementType == nullptr) { 1057 if (!ForwardPointerTypes.contains(Ty)) 1058 ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder); 1059 return ForwardPointerTypes[Ty]; 1060 } 1061 // If we have forward pointer associated with this type, use its register 1062 // operand to create OpTypePointer. 1063 if (ForwardPointerTypes.contains(Ty)) { 1064 Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]); 1065 return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg); 1066 } 1067 1068 return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC); 1069 } 1070 1071 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( 1072 const Type *Ty, MachineIRBuilder &MIRBuilder, 1073 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 1074 if (TypesInProcessing.count(Ty) && !isPointerTyOrWrapper(Ty)) 1075 return nullptr; 1076 TypesInProcessing.insert(Ty); 1077 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 1078 TypesInProcessing.erase(Ty); 1079 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; 1080 SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty); 1081 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 1082 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type 1083 // will be added later. For special types it is already added to DT. 1084 if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() && 1085 !isSpecialOpaqueType(Ty)) { 1086 if (auto *ExtTy = dyn_cast<TargetExtType>(Ty); 1087 ExtTy && isTypedPointerWrapper(ExtTy)) 1088 DT.add(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), 1089 &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); 1090 else if (!isPointerTy(Ty)) 1091 DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); 1092 else if (isTypedPointerTy(Ty)) 1093 DT.add(cast<TypedPointerType>(Ty)->getElementType(), 1094 getPointerAddressSpace(Ty), &MIRBuilder.getMF(), 1095 getSPIRVTypeID(SpirvType)); 1096 else 1097 DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()), 1098 getPointerAddressSpace(Ty), &MIRBuilder.getMF(), 1099 getSPIRVTypeID(SpirvType)); 1100 } 1101 1102 return SpirvType; 1103 } 1104 1105 SPIRVType * 1106 SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg, 1107 const MachineFunction *MF) const { 1108 auto t = VRegToTypeMap.find(MF ? MF : CurMF); 1109 if (t != VRegToTypeMap.end()) { 1110 auto tt = t->second.find(VReg); 1111 if (tt != t->second.end()) 1112 return tt->second; 1113 } 1114 return nullptr; 1115 } 1116 1117 SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg, 1118 MachineFunction *MF) { 1119 if (!MF) 1120 MF = CurMF; 1121 MachineInstr *Instr = getVRegDef(MF->getRegInfo(), VReg); 1122 return getSPIRVTypeForVReg(Instr->getOperand(1).getReg(), MF); 1123 } 1124 1125 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( 1126 const Type *Ty, MachineIRBuilder &MIRBuilder, 1127 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 1128 Register Reg; 1129 if (auto *ExtTy = dyn_cast<TargetExtType>(Ty); 1130 ExtTy && isTypedPointerWrapper(ExtTy)) { 1131 Reg = DT.find(ExtTy->getTypeParameter(0), ExtTy->getIntParameter(0), 1132 &MIRBuilder.getMF()); 1133 } else if (!isPointerTy(Ty)) { 1134 Ty = adjustIntTypeByWidth(Ty); 1135 Reg = DT.find(Ty, &MIRBuilder.getMF()); 1136 } else if (isTypedPointerTy(Ty)) { 1137 Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(), 1138 getPointerAddressSpace(Ty), &MIRBuilder.getMF()); 1139 } else { 1140 Reg = 1141 DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()), 1142 getPointerAddressSpace(Ty), &MIRBuilder.getMF()); 1143 } 1144 1145 if (Reg.isValid() && !isSpecialOpaqueType(Ty)) 1146 return getSPIRVTypeForVReg(Reg); 1147 TypesInProcessing.clear(); 1148 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 1149 // Create normal pointer types for the corresponding OpTypeForwardPointers. 1150 for (auto &CU : ForwardPointerTypes) { 1151 const Type *Ty2 = CU.first; 1152 SPIRVType *STy2 = CU.second; 1153 if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid()) 1154 STy2 = getSPIRVTypeForVReg(Reg); 1155 else 1156 STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR); 1157 if (Ty == Ty2) 1158 STy = STy2; 1159 } 1160 ForwardPointerTypes.clear(); 1161 return STy; 1162 } 1163 1164 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, 1165 unsigned TypeOpcode) const { 1166 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 1167 assert(Type && "isScalarOfType VReg has no type assigned"); 1168 return Type->getOpcode() == TypeOpcode; 1169 } 1170 1171 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, 1172 unsigned TypeOpcode) const { 1173 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 1174 assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); 1175 if (Type->getOpcode() == TypeOpcode) 1176 return true; 1177 if (Type->getOpcode() == SPIRV::OpTypeVector) { 1178 Register ScalarTypeVReg = Type->getOperand(1).getReg(); 1179 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); 1180 return ScalarType->getOpcode() == TypeOpcode; 1181 } 1182 return false; 1183 } 1184 1185 unsigned 1186 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const { 1187 return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg)); 1188 } 1189 1190 unsigned 1191 SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const { 1192 if (!Type) 1193 return 0; 1194 return Type->getOpcode() == SPIRV::OpTypeVector 1195 ? static_cast<unsigned>(Type->getOperand(2).getImm()) 1196 : 1; 1197 } 1198 1199 SPIRVType * 1200 SPIRVGlobalRegistry::getScalarOrVectorComponentType(Register VReg) const { 1201 return getScalarOrVectorComponentType(getSPIRVTypeForVReg(VReg)); 1202 } 1203 1204 SPIRVType * 1205 SPIRVGlobalRegistry::getScalarOrVectorComponentType(SPIRVType *Type) const { 1206 if (!Type) 1207 return nullptr; 1208 Register ScalarReg = Type->getOpcode() == SPIRV::OpTypeVector 1209 ? Type->getOperand(1).getReg() 1210 : Type->getOperand(0).getReg(); 1211 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarReg); 1212 assert(isScalarOrVectorOfType(Type->getOperand(0).getReg(), 1213 ScalarType->getOpcode())); 1214 return ScalarType; 1215 } 1216 1217 unsigned 1218 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { 1219 assert(Type && "Invalid Type pointer"); 1220 if (Type->getOpcode() == SPIRV::OpTypeVector) { 1221 auto EleTypeReg = Type->getOperand(1).getReg(); 1222 Type = getSPIRVTypeForVReg(EleTypeReg); 1223 } 1224 if (Type->getOpcode() == SPIRV::OpTypeInt || 1225 Type->getOpcode() == SPIRV::OpTypeFloat) 1226 return Type->getOperand(1).getImm(); 1227 if (Type->getOpcode() == SPIRV::OpTypeBool) 1228 return 1; 1229 llvm_unreachable("Attempting to get bit width of non-integer/float type."); 1230 } 1231 1232 unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth( 1233 const SPIRVType *Type) const { 1234 assert(Type && "Invalid Type pointer"); 1235 unsigned NumElements = 1; 1236 if (Type->getOpcode() == SPIRV::OpTypeVector) { 1237 NumElements = static_cast<unsigned>(Type->getOperand(2).getImm()); 1238 Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg()); 1239 } 1240 return Type->getOpcode() == SPIRV::OpTypeInt || 1241 Type->getOpcode() == SPIRV::OpTypeFloat 1242 ? NumElements * Type->getOperand(1).getImm() 1243 : 0; 1244 } 1245 1246 const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType( 1247 const SPIRVType *Type) const { 1248 if (Type && Type->getOpcode() == SPIRV::OpTypeVector) 1249 Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg()); 1250 return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr; 1251 } 1252 1253 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { 1254 const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type); 1255 return IntType && IntType->getOperand(2).getImm() != 0; 1256 } 1257 1258 SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) { 1259 return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer 1260 ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg()) 1261 : nullptr; 1262 } 1263 1264 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) { 1265 SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg)); 1266 return ElemType ? ElemType->getOpcode() : 0; 1267 } 1268 1269 bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1, 1270 const SPIRVType *Type2) const { 1271 if (!Type1 || !Type2) 1272 return false; 1273 auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode(); 1274 // Ignore difference between <1.5 and >=1.5 protocol versions: 1275 // it's valid if either Result Type or Operand is a pointer, and the other 1276 // is a pointer, an integer scalar, or an integer vector. 1277 if (Op1 == SPIRV::OpTypePointer && 1278 (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2))) 1279 return true; 1280 if (Op2 == SPIRV::OpTypePointer && 1281 (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1))) 1282 return true; 1283 unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1), 1284 Bits2 = getNumScalarOrVectorTotalBitWidth(Type2); 1285 return Bits1 > 0 && Bits1 == Bits2; 1286 } 1287 1288 SPIRV::StorageClass::StorageClass 1289 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { 1290 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 1291 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && 1292 Type->getOperand(1).isImm() && "Pointer type is expected"); 1293 return getPointerStorageClass(Type); 1294 } 1295 1296 SPIRV::StorageClass::StorageClass 1297 SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const { 1298 return static_cast<SPIRV::StorageClass::StorageClass>( 1299 Type->getOperand(1).getImm()); 1300 } 1301 1302 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage( 1303 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim, 1304 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled, 1305 SPIRV::ImageFormat::ImageFormat ImageFormat, 1306 SPIRV::AccessQualifier::AccessQualifier AccessQual) { 1307 auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim, 1308 Depth, Arrayed, Multisampled, Sampled, 1309 ImageFormat, AccessQual); 1310 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 1311 return Res; 1312 Register ResVReg = createTypeVReg(MIRBuilder); 1313 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 1314 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeImage) 1315 .addDef(ResVReg) 1316 .addUse(getSPIRVTypeID(SampledType)) 1317 .addImm(Dim) 1318 .addImm(Depth) // Depth (whether or not it is a Depth image). 1319 .addImm(Arrayed) // Arrayed. 1320 .addImm(Multisampled) // Multisampled (0 = only single-sample). 1321 .addImm(Sampled) // Sampled (0 = usage known at runtime). 1322 .addImm(ImageFormat); 1323 1324 if (AccessQual != SPIRV::AccessQualifier::None) 1325 MIB.addImm(AccessQual); 1326 return MIB; 1327 } 1328 1329 SPIRVType * 1330 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) { 1331 auto TD = SPIRV::make_descr_sampler(); 1332 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 1333 return Res; 1334 Register ResVReg = createTypeVReg(MIRBuilder); 1335 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 1336 return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg); 1337 } 1338 1339 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe( 1340 MachineIRBuilder &MIRBuilder, 1341 SPIRV::AccessQualifier::AccessQualifier AccessQual) { 1342 auto TD = SPIRV::make_descr_pipe(AccessQual); 1343 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 1344 return Res; 1345 Register ResVReg = createTypeVReg(MIRBuilder); 1346 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 1347 return MIRBuilder.buildInstr(SPIRV::OpTypePipe) 1348 .addDef(ResVReg) 1349 .addImm(AccessQual); 1350 } 1351 1352 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent( 1353 MachineIRBuilder &MIRBuilder) { 1354 auto TD = SPIRV::make_descr_event(); 1355 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 1356 return Res; 1357 Register ResVReg = createTypeVReg(MIRBuilder); 1358 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 1359 return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg); 1360 } 1361 1362 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage( 1363 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) { 1364 auto TD = SPIRV::make_descr_sampled_image( 1365 SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef( 1366 ImageType->getOperand(1).getReg())), 1367 ImageType); 1368 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 1369 return Res; 1370 Register ResVReg = createTypeVReg(MIRBuilder); 1371 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 1372 return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage) 1373 .addDef(ResVReg) 1374 .addUse(getSPIRVTypeID(ImageType)); 1375 } 1376 1377 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr( 1378 MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType, 1379 const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns, 1380 uint32_t Use) { 1381 Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF()); 1382 if (ResVReg.isValid()) 1383 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg); 1384 ResVReg = createTypeVReg(MIRBuilder); 1385 SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder); 1386 SPIRVType *SpirvTy = 1387 MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR) 1388 .addDef(ResVReg) 1389 .addUse(getSPIRVTypeID(ElemType)) 1390 .addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, true)) 1391 .addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, true)) 1392 .addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, true)) 1393 .addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, true)); 1394 DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg); 1395 return SpirvTy; 1396 } 1397 1398 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode( 1399 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) { 1400 Register ResVReg = DT.find(Ty, &MIRBuilder.getMF()); 1401 if (ResVReg.isValid()) 1402 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg); 1403 ResVReg = createTypeVReg(MIRBuilder); 1404 SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg); 1405 DT.add(Ty, &MIRBuilder.getMF(), ResVReg); 1406 return SpirvTy; 1407 } 1408 1409 const MachineInstr * 1410 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD, 1411 MachineIRBuilder &MIRBuilder) { 1412 Register Reg = DT.find(TD, &MIRBuilder.getMF()); 1413 if (Reg.isValid()) 1414 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg); 1415 return nullptr; 1416 } 1417 1418 // Returns nullptr if unable to recognize SPIRV type name 1419 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName( 1420 StringRef TypeStr, MachineIRBuilder &MIRBuilder, 1421 SPIRV::StorageClass::StorageClass SC, 1422 SPIRV::AccessQualifier::AccessQualifier AQ) { 1423 unsigned VecElts = 0; 1424 auto &Ctx = MIRBuilder.getMF().getFunction().getContext(); 1425 1426 // Parse strings representing either a SPIR-V or OpenCL builtin type. 1427 if (hasBuiltinTypePrefix(TypeStr)) 1428 return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType( 1429 TypeStr.str(), MIRBuilder.getContext()), 1430 MIRBuilder, AQ); 1431 1432 // Parse type name in either "typeN" or "type vector[N]" format, where 1433 // N is the number of elements of the vector. 1434 Type *Ty; 1435 1436 Ty = parseBasicTypeName(TypeStr, Ctx); 1437 if (!Ty) 1438 // Unable to recognize SPIRV type name 1439 return nullptr; 1440 1441 auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ); 1442 1443 // Handle "type*" or "type* vector[N]". 1444 if (TypeStr.starts_with("*")) { 1445 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC); 1446 TypeStr = TypeStr.substr(strlen("*")); 1447 } 1448 1449 // Handle "typeN*" or "type vector[N]*". 1450 bool IsPtrToVec = TypeStr.consume_back("*"); 1451 1452 if (TypeStr.consume_front(" vector[")) { 1453 TypeStr = TypeStr.substr(0, TypeStr.find(']')); 1454 } 1455 TypeStr.getAsInteger(10, VecElts); 1456 if (VecElts > 0) 1457 SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder); 1458 1459 if (IsPtrToVec) 1460 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC); 1461 1462 return SpirvTy; 1463 } 1464 1465 SPIRVType * 1466 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, 1467 MachineIRBuilder &MIRBuilder) { 1468 return getOrCreateSPIRVType( 1469 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), 1470 MIRBuilder); 1471 } 1472 1473 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, 1474 SPIRVType *SpirvType) { 1475 assert(CurMF == SpirvType->getMF()); 1476 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; 1477 SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy); 1478 return SpirvType; 1479 } 1480 1481 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth, 1482 MachineInstr &I, 1483 const SPIRVInstrInfo &TII, 1484 unsigned SPIRVOPcode, 1485 Type *LLVMTy) { 1486 Register Reg = DT.find(LLVMTy, CurMF); 1487 if (Reg.isValid()) 1488 return getSPIRVTypeForVReg(Reg); 1489 MachineBasicBlock &BB = *I.getParent(); 1490 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode)) 1491 .addDef(createTypeVReg(CurMF->getRegInfo())) 1492 .addImm(BitWidth) 1493 .addImm(0); 1494 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1495 return finishCreatingSPIRVType(LLVMTy, MIB); 1496 } 1497 1498 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( 1499 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { 1500 // Maybe adjust bit width to keep DuplicateTracker consistent. Without 1501 // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for 1502 // example, the same "OpTypeInt 8" type for a series of LLVM integer types 1503 // with number of bits less than 8, causing duplicate type definitions. 1504 BitWidth = adjustOpTypeIntWidth(BitWidth); 1505 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); 1506 return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy); 1507 } 1508 1509 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType( 1510 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { 1511 LLVMContext &Ctx = CurMF->getFunction().getContext(); 1512 Type *LLVMTy; 1513 switch (BitWidth) { 1514 case 16: 1515 LLVMTy = Type::getHalfTy(Ctx); 1516 break; 1517 case 32: 1518 LLVMTy = Type::getFloatTy(Ctx); 1519 break; 1520 case 64: 1521 LLVMTy = Type::getDoubleTy(Ctx); 1522 break; 1523 default: 1524 llvm_unreachable("Bit width is of unexpected size."); 1525 } 1526 return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy); 1527 } 1528 1529 SPIRVType * 1530 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { 1531 return getOrCreateSPIRVType( 1532 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), 1533 MIRBuilder); 1534 } 1535 1536 SPIRVType * 1537 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I, 1538 const SPIRVInstrInfo &TII) { 1539 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1); 1540 Register Reg = DT.find(LLVMTy, CurMF); 1541 if (Reg.isValid()) 1542 return getSPIRVTypeForVReg(Reg); 1543 MachineBasicBlock &BB = *I.getParent(); 1544 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool)) 1545 .addDef(createTypeVReg(CurMF->getRegInfo())); 1546 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1547 return finishCreatingSPIRVType(LLVMTy, MIB); 1548 } 1549 1550 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 1551 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { 1552 return getOrCreateSPIRVType( 1553 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 1554 NumElements), 1555 MIRBuilder); 1556 } 1557 1558 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 1559 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 1560 const SPIRVInstrInfo &TII) { 1561 Type *LLVMTy = FixedVectorType::get( 1562 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 1563 Register Reg = DT.find(LLVMTy, CurMF); 1564 if (Reg.isValid()) 1565 return getSPIRVTypeForVReg(Reg); 1566 MachineBasicBlock &BB = *I.getParent(); 1567 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) 1568 .addDef(createTypeVReg(CurMF->getRegInfo())) 1569 .addUse(getSPIRVTypeID(BaseType)) 1570 .addImm(NumElements); 1571 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1572 return finishCreatingSPIRVType(LLVMTy, MIB); 1573 } 1574 1575 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType( 1576 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 1577 const SPIRVInstrInfo &TII) { 1578 Type *LLVMTy = ArrayType::get( 1579 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 1580 Register Reg = DT.find(LLVMTy, CurMF); 1581 if (Reg.isValid()) 1582 return getSPIRVTypeForVReg(Reg); 1583 MachineBasicBlock &BB = *I.getParent(); 1584 SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, I, TII); 1585 Register Len = getOrCreateConstInt(NumElements, I, SpvTypeInt32, TII); 1586 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray)) 1587 .addDef(createTypeVReg(CurMF->getRegInfo())) 1588 .addUse(getSPIRVTypeID(BaseType)) 1589 .addUse(Len); 1590 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1591 return finishCreatingSPIRVType(LLVMTy, MIB); 1592 } 1593 1594 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 1595 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, 1596 SPIRV::StorageClass::StorageClass SC) { 1597 const Type *PointerElementType = getTypeForSPIRVType(BaseType); 1598 unsigned AddressSpace = storageClassToAddressSpace(SC); 1599 Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType), 1600 AddressSpace); 1601 // check if this type is already available 1602 Register Reg = DT.find(PointerElementType, AddressSpace, CurMF); 1603 if (Reg.isValid()) 1604 return getSPIRVTypeForVReg(Reg); 1605 // create a new type 1606 return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) { 1607 auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(), 1608 MIRBuilder.getDebugLoc(), 1609 MIRBuilder.getTII().get(SPIRV::OpTypePointer)) 1610 .addDef(createTypeVReg(CurMF->getRegInfo())) 1611 .addImm(static_cast<uint32_t>(SC)) 1612 .addUse(getSPIRVTypeID(BaseType)); 1613 DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB)); 1614 finishCreatingSPIRVType(LLVMTy, MIB); 1615 return MIB; 1616 }); 1617 } 1618 1619 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 1620 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &, 1621 SPIRV::StorageClass::StorageClass SC) { 1622 MachineIRBuilder MIRBuilder(I); 1623 return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC); 1624 } 1625 1626 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I, 1627 SPIRVType *SpvType, 1628 const SPIRVInstrInfo &TII) { 1629 assert(SpvType); 1630 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 1631 assert(LLVMTy); 1632 // Find a constant in DT or build a new one. 1633 UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy)); 1634 Register Res = DT.find(UV, CurMF); 1635 if (Res.isValid()) 1636 return Res; 1637 LLT LLTy = LLT::scalar(64); 1638 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 1639 CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass); 1640 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 1641 DT.add(UV, CurMF, Res); 1642 1643 MachineInstrBuilder MIB; 1644 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef)) 1645 .addDef(Res) 1646 .addUse(getSPIRVTypeID(SpvType)); 1647 const auto &ST = CurMF->getSubtarget(); 1648 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 1649 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 1650 return Res; 1651 } 1652 1653 const TargetRegisterClass * 1654 SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const { 1655 unsigned Opcode = SpvType->getOpcode(); 1656 switch (Opcode) { 1657 case SPIRV::OpTypeFloat: 1658 return &SPIRV::fIDRegClass; 1659 case SPIRV::OpTypePointer: 1660 return &SPIRV::pIDRegClass; 1661 case SPIRV::OpTypeVector: { 1662 SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 1663 unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0; 1664 if (ElemOpcode == SPIRV::OpTypeFloat) 1665 return &SPIRV::vfIDRegClass; 1666 if (ElemOpcode == SPIRV::OpTypePointer) 1667 return &SPIRV::vpIDRegClass; 1668 return &SPIRV::vIDRegClass; 1669 } 1670 } 1671 return &SPIRV::iIDRegClass; 1672 } 1673 1674 inline unsigned getAS(SPIRVType *SpvType) { 1675 return storageClassToAddressSpace( 1676 static_cast<SPIRV::StorageClass::StorageClass>( 1677 SpvType->getOperand(1).getImm())); 1678 } 1679 1680 LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const { 1681 unsigned Opcode = SpvType ? SpvType->getOpcode() : 0; 1682 switch (Opcode) { 1683 case SPIRV::OpTypeInt: 1684 case SPIRV::OpTypeFloat: 1685 case SPIRV::OpTypeBool: 1686 return LLT::scalar(getScalarOrVectorBitWidth(SpvType)); 1687 case SPIRV::OpTypePointer: 1688 return LLT::pointer(getAS(SpvType), getPointerSize()); 1689 case SPIRV::OpTypeVector: { 1690 SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 1691 LLT ET; 1692 switch (ElemType ? ElemType->getOpcode() : 0) { 1693 case SPIRV::OpTypePointer: 1694 ET = LLT::pointer(getAS(ElemType), getPointerSize()); 1695 break; 1696 case SPIRV::OpTypeInt: 1697 case SPIRV::OpTypeFloat: 1698 case SPIRV::OpTypeBool: 1699 ET = LLT::scalar(getScalarOrVectorBitWidth(ElemType)); 1700 break; 1701 default: 1702 ET = LLT::scalar(64); 1703 } 1704 return LLT::fixed_vector( 1705 static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET); 1706 } 1707 } 1708 return LLT::scalar(64); 1709 } 1710