1 //===- DXILOpBuilder.cpp - Helper class for build DIXLOp functions --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains class to help build DXIL op functions. 10 //===----------------------------------------------------------------------===// 11 12 #include "DXILOpBuilder.h" 13 #include "DXILConstants.h" 14 #include "llvm/IR/Module.h" 15 #include "llvm/Support/DXILABI.h" 16 #include "llvm/Support/ErrorHandling.h" 17 #include <optional> 18 19 using namespace llvm; 20 using namespace llvm::dxil; 21 22 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 23 24 namespace { 25 enum OverloadKind : uint16_t { 26 UNDEFINED = 0, 27 VOID = 1, 28 HALF = 1 << 1, 29 FLOAT = 1 << 2, 30 DOUBLE = 1 << 3, 31 I1 = 1 << 4, 32 I8 = 1 << 5, 33 I16 = 1 << 6, 34 I32 = 1 << 7, 35 I64 = 1 << 8, 36 UserDefineType = 1 << 9, 37 ObjectType = 1 << 10, 38 }; 39 struct Version { 40 unsigned Major = 0; 41 unsigned Minor = 0; 42 }; 43 44 struct OpOverload { 45 Version DXILVersion; 46 uint16_t ValidTys; 47 }; 48 } // namespace 49 50 struct OpStage { 51 Version DXILVersion; 52 uint32_t ValidStages; 53 }; 54 55 static const char *getOverloadTypeName(OverloadKind Kind) { 56 switch (Kind) { 57 case OverloadKind::HALF: 58 return "f16"; 59 case OverloadKind::FLOAT: 60 return "f32"; 61 case OverloadKind::DOUBLE: 62 return "f64"; 63 case OverloadKind::I1: 64 return "i1"; 65 case OverloadKind::I8: 66 return "i8"; 67 case OverloadKind::I16: 68 return "i16"; 69 case OverloadKind::I32: 70 return "i32"; 71 case OverloadKind::I64: 72 return "i64"; 73 case OverloadKind::VOID: 74 case OverloadKind::UNDEFINED: 75 return "void"; 76 case OverloadKind::ObjectType: 77 case OverloadKind::UserDefineType: 78 break; 79 } 80 llvm_unreachable("invalid overload type for name"); 81 } 82 83 static OverloadKind getOverloadKind(Type *Ty) { 84 if (!Ty) 85 return OverloadKind::VOID; 86 87 Type::TypeID T = Ty->getTypeID(); 88 switch (T) { 89 case Type::VoidTyID: 90 return OverloadKind::VOID; 91 case Type::HalfTyID: 92 return OverloadKind::HALF; 93 case Type::FloatTyID: 94 return OverloadKind::FLOAT; 95 case Type::DoubleTyID: 96 return OverloadKind::DOUBLE; 97 case Type::IntegerTyID: { 98 IntegerType *ITy = cast<IntegerType>(Ty); 99 unsigned Bits = ITy->getBitWidth(); 100 switch (Bits) { 101 case 1: 102 return OverloadKind::I1; 103 case 8: 104 return OverloadKind::I8; 105 case 16: 106 return OverloadKind::I16; 107 case 32: 108 return OverloadKind::I32; 109 case 64: 110 return OverloadKind::I64; 111 default: 112 llvm_unreachable("invalid overload type"); 113 return OverloadKind::VOID; 114 } 115 } 116 case Type::PointerTyID: 117 return OverloadKind::UserDefineType; 118 case Type::StructTyID: { 119 // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework 120 // how we're handling overloads and remove the `OverloadKind` proxy enum. 121 StructType *ST = cast<StructType>(Ty); 122 return getOverloadKind(ST->getElementType(0)); 123 } 124 default: 125 return OverloadKind::UNDEFINED; 126 } 127 } 128 129 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 130 if (Kind < OverloadKind::UserDefineType) { 131 return getOverloadTypeName(Kind); 132 } else if (Kind == OverloadKind::UserDefineType) { 133 StructType *ST = cast<StructType>(Ty); 134 return ST->getStructName().str(); 135 } else if (Kind == OverloadKind::ObjectType) { 136 StructType *ST = cast<StructType>(Ty); 137 return ST->getStructName().str(); 138 } else { 139 std::string Str; 140 raw_string_ostream OS(Str); 141 Ty->print(OS); 142 return OS.str(); 143 } 144 } 145 146 // Static properties. 147 struct OpCodeProperty { 148 dxil::OpCode OpCode; 149 // Offset in DXILOpCodeNameTable. 150 unsigned OpCodeNameOffset; 151 dxil::OpCodeClass OpCodeClass; 152 // Offset in DXILOpCodeClassNameTable. 153 unsigned OpCodeClassNameOffset; 154 llvm::SmallVector<OpOverload> Overloads; 155 llvm::SmallVector<OpStage> Stages; 156 int OverloadParamIndex; // parameter index which control the overload. 157 // When < 0, should be only 1 overload type. 158 }; 159 160 // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and 161 // getOpCodeParameterKind which generated by tableGen. 162 #define DXIL_OP_OPERATION_TABLE 163 #include "DXILOperation.inc" 164 #undef DXIL_OP_OPERATION_TABLE 165 166 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 167 const OpCodeProperty &Prop) { 168 if (Kind == OverloadKind::VOID) { 169 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 170 } 171 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 172 getTypeName(Kind, Ty)) 173 .str(); 174 } 175 176 static std::string constructOverloadTypeName(OverloadKind Kind, 177 StringRef TypeName) { 178 if (Kind == OverloadKind::VOID) 179 return TypeName.str(); 180 181 assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); 182 return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); 183 } 184 185 static StructType *getOrCreateStructType(StringRef Name, 186 ArrayRef<Type *> EltTys, 187 LLVMContext &Ctx) { 188 StructType *ST = StructType::getTypeByName(Ctx, Name); 189 if (ST) 190 return ST; 191 192 return StructType::create(Ctx, EltTys, Name); 193 } 194 195 static StructType *getResRetType(Type *ElementTy) { 196 LLVMContext &Ctx = ElementTy->getContext(); 197 OverloadKind Kind = getOverloadKind(ElementTy); 198 std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); 199 Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy, 200 Type::getInt32Ty(Ctx)}; 201 return getOrCreateStructType(TypeName, FieldTypes, Ctx); 202 } 203 204 static StructType *getHandleType(LLVMContext &Ctx) { 205 return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx), 206 Ctx); 207 } 208 209 static StructType *getResBindType(LLVMContext &Context) { 210 if (auto *ST = StructType::getTypeByName(Context, "dx.types.ResBind")) 211 return ST; 212 Type *Int32Ty = Type::getInt32Ty(Context); 213 Type *Int8Ty = Type::getInt8Ty(Context); 214 return StructType::create({Int32Ty, Int32Ty, Int32Ty, Int8Ty}, 215 "dx.types.ResBind"); 216 } 217 218 static StructType *getResPropsType(LLVMContext &Context) { 219 if (auto *ST = 220 StructType::getTypeByName(Context, "dx.types.ResourceProperties")) 221 return ST; 222 Type *Int32Ty = Type::getInt32Ty(Context); 223 return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties"); 224 } 225 226 static StructType *getSplitDoubleType(LLVMContext &Context) { 227 if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble")) 228 return ST; 229 Type *Int32Ty = Type::getInt32Ty(Context); 230 return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble"); 231 } 232 233 static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, 234 Type *OverloadTy) { 235 switch (Kind) { 236 case OpParamType::VoidTy: 237 return Type::getVoidTy(Ctx); 238 case OpParamType::HalfTy: 239 return Type::getHalfTy(Ctx); 240 case OpParamType::FloatTy: 241 return Type::getFloatTy(Ctx); 242 case OpParamType::DoubleTy: 243 return Type::getDoubleTy(Ctx); 244 case OpParamType::Int1Ty: 245 return Type::getInt1Ty(Ctx); 246 case OpParamType::Int8Ty: 247 return Type::getInt8Ty(Ctx); 248 case OpParamType::Int16Ty: 249 return Type::getInt16Ty(Ctx); 250 case OpParamType::Int32Ty: 251 return Type::getInt32Ty(Ctx); 252 case OpParamType::Int64Ty: 253 return Type::getInt64Ty(Ctx); 254 case OpParamType::OverloadTy: 255 return OverloadTy; 256 case OpParamType::ResRetHalfTy: 257 return getResRetType(Type::getHalfTy(Ctx)); 258 case OpParamType::ResRetFloatTy: 259 return getResRetType(Type::getFloatTy(Ctx)); 260 case OpParamType::ResRetDoubleTy: 261 return getResRetType(Type::getDoubleTy(Ctx)); 262 case OpParamType::ResRetInt16Ty: 263 return getResRetType(Type::getInt16Ty(Ctx)); 264 case OpParamType::ResRetInt32Ty: 265 return getResRetType(Type::getInt32Ty(Ctx)); 266 case OpParamType::ResRetInt64Ty: 267 return getResRetType(Type::getInt64Ty(Ctx)); 268 case OpParamType::HandleTy: 269 return getHandleType(Ctx); 270 case OpParamType::ResBindTy: 271 return getResBindType(Ctx); 272 case OpParamType::ResPropsTy: 273 return getResPropsType(Ctx); 274 case OpParamType::SplitDoubleTy: 275 return getSplitDoubleType(Ctx); 276 } 277 llvm_unreachable("Invalid parameter kind"); 278 return nullptr; 279 } 280 281 static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) { 282 switch (EnvType) { 283 case Triple::Pixel: 284 return ShaderKind::pixel; 285 case Triple::Vertex: 286 return ShaderKind::vertex; 287 case Triple::Geometry: 288 return ShaderKind::geometry; 289 case Triple::Hull: 290 return ShaderKind::hull; 291 case Triple::Domain: 292 return ShaderKind::domain; 293 case Triple::Compute: 294 return ShaderKind::compute; 295 case Triple::Library: 296 return ShaderKind::library; 297 case Triple::RayGeneration: 298 return ShaderKind::raygeneration; 299 case Triple::Intersection: 300 return ShaderKind::intersection; 301 case Triple::AnyHit: 302 return ShaderKind::anyhit; 303 case Triple::ClosestHit: 304 return ShaderKind::closesthit; 305 case Triple::Miss: 306 return ShaderKind::miss; 307 case Triple::Callable: 308 return ShaderKind::callable; 309 case Triple::Mesh: 310 return ShaderKind::mesh; 311 case Triple::Amplification: 312 return ShaderKind::amplification; 313 default: 314 break; 315 } 316 llvm_unreachable( 317 "Shader Kind Not Found - Invalid DXIL Environment Specified"); 318 } 319 320 static SmallVector<Type *> 321 getArgTypesFromOpParamTypes(ArrayRef<dxil::OpParamType> Types, 322 LLVMContext &Context, Type *OverloadTy) { 323 SmallVector<Type *> ArgTys; 324 ArgTys.emplace_back(Type::getInt32Ty(Context)); 325 for (dxil::OpParamType Ty : Types) 326 ArgTys.emplace_back(getTypeFromOpParamType(Ty, Context, OverloadTy)); 327 return ArgTys; 328 } 329 330 /// Construct DXIL function type. This is the type of a function with 331 /// the following prototype 332 /// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>) 333 /// <param-types> are constructed from types in Prop. 334 static FunctionType *getDXILOpFunctionType(dxil::OpCode OpCode, 335 LLVMContext &Context, 336 Type *OverloadTy) { 337 338 switch (OpCode) { 339 #define DXIL_OP_FUNCTION_TYPE(OpCode, RetType, ...) \ 340 case OpCode: \ 341 return FunctionType::get( \ 342 getTypeFromOpParamType(RetType, Context, OverloadTy), \ 343 getArgTypesFromOpParamTypes({__VA_ARGS__}, Context, OverloadTy), \ 344 /*isVarArg=*/false); 345 #include "DXILOperation.inc" 346 } 347 llvm_unreachable("Invalid OpCode?"); 348 } 349 350 /// Get index of the property from PropList valid for the most recent 351 /// DXIL version not greater than DXILVer. 352 /// PropList is expected to be sorted in ascending order of DXIL version. 353 template <typename T> 354 static std::optional<size_t> getPropIndex(ArrayRef<T> PropList, 355 const VersionTuple DXILVer) { 356 size_t Index = PropList.size() - 1; 357 for (auto Iter = PropList.rbegin(); Iter != PropList.rend(); 358 Iter++, Index--) { 359 const T &Prop = *Iter; 360 if (VersionTuple(Prop.DXILVersion.Major, Prop.DXILVersion.Minor) <= 361 DXILVer) { 362 return Index; 363 } 364 } 365 return std::nullopt; 366 } 367 368 // Helper function to pack an OpCode and VersionTuple into a uint64_t for use 369 // in a switch statement 370 constexpr static uint64_t computeSwitchEnum(dxil::OpCode OpCode, 371 uint16_t VersionMajor, 372 uint16_t VersionMinor) { 373 uint64_t OpCodePack = (uint64_t)OpCode; 374 return (OpCodePack << 32) | (VersionMajor << 16) | VersionMinor; 375 } 376 377 // Retreive all the set attributes for a DXIL OpCode given the targeted 378 // DXILVersion 379 static dxil::Attributes getDXILAttributes(dxil::OpCode OpCode, 380 VersionTuple DXILVersion) { 381 // Instantiate all versions to iterate through 382 SmallVector<Version> Versions = { 383 #define DXIL_VERSION(MAJOR, MINOR) {MAJOR, MINOR}, 384 #include "DXILOperation.inc" 385 }; 386 387 dxil::Attributes Attributes; 388 for (auto Version : Versions) { 389 if (DXILVersion < VersionTuple(Version.Major, Version.Minor)) 390 continue; 391 392 // Switch through and match an OpCode with the specific version and set the 393 // corresponding flag(s) if available 394 switch (computeSwitchEnum(OpCode, Version.Major, Version.Minor)) { 395 #define DXIL_OP_ATTRIBUTES(OpCode, VersionMajor, VersionMinor, ...) \ 396 case computeSwitchEnum(OpCode, VersionMajor, VersionMinor): { \ 397 auto Other = dxil::Attributes{__VA_ARGS__}; \ 398 Attributes |= Other; \ 399 break; \ 400 }; 401 #include "DXILOperation.inc" 402 } 403 } 404 return Attributes; 405 } 406 407 // Retreive the set of DXIL Attributes given the version and map them to an 408 // llvm function attribute that is set onto the instruction 409 static void setDXILAttributes(CallInst *CI, dxil::OpCode OpCode, 410 VersionTuple DXILVersion) { 411 dxil::Attributes Attributes = getDXILAttributes(OpCode, DXILVersion); 412 if (Attributes.ReadNone) 413 CI->setDoesNotAccessMemory(); 414 if (Attributes.ReadOnly) 415 CI->setOnlyReadsMemory(); 416 if (Attributes.NoReturn) 417 CI->setDoesNotReturn(); 418 if (Attributes.NoDuplicate) 419 CI->setCannotDuplicate(); 420 return; 421 } 422 423 namespace llvm { 424 namespace dxil { 425 426 // No extra checks on TargetTriple need be performed to verify that the 427 // Triple is well-formed or that the target is supported since these checks 428 // would have been done at the time the module M is constructed in the earlier 429 // stages of compilation. 430 DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) { 431 Triple TT(Triple(M.getTargetTriple())); 432 DXILVersion = TT.getDXILVersion(); 433 ShaderStage = TT.getEnvironment(); 434 // Ensure Environment type is known 435 if (ShaderStage == Triple::UnknownEnvironment) { 436 report_fatal_error( 437 Twine(DXILVersion.getAsString()) + 438 ": Unknown Compilation Target Shader Stage specified ", 439 /*gen_crash_diag*/ false); 440 } 441 } 442 443 static Error makeOpError(dxil::OpCode OpCode, Twine Msg) { 444 return make_error<StringError>( 445 Twine("Cannot create ") + getOpCodeName(OpCode) + " operation: " + Msg, 446 inconvertibleErrorCode()); 447 } 448 449 Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode, 450 ArrayRef<Value *> Args, 451 const Twine &Name, 452 Type *RetTy) { 453 const OpCodeProperty *Prop = getOpCodeProperty(OpCode); 454 455 Type *OverloadTy = nullptr; 456 if (Prop->OverloadParamIndex == 0) { 457 if (!RetTy) 458 return makeOpError(OpCode, "Op overloaded on unknown return type"); 459 OverloadTy = RetTy; 460 } else if (Prop->OverloadParamIndex > 0) { 461 // The index counts including the return type 462 unsigned ArgIndex = Prop->OverloadParamIndex - 1; 463 if (static_cast<unsigned>(ArgIndex) >= Args.size()) 464 return makeOpError(OpCode, "Wrong number of arguments"); 465 OverloadTy = Args[ArgIndex]->getType(); 466 } 467 468 FunctionType *DXILOpFT = 469 getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy); 470 471 std::optional<size_t> OlIndexOrErr = 472 getPropIndex(ArrayRef(Prop->Overloads), DXILVersion); 473 if (!OlIndexOrErr.has_value()) 474 return makeOpError(OpCode, Twine("No valid overloads for DXIL version ") + 475 DXILVersion.getAsString()); 476 477 uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys; 478 479 OverloadKind Kind = getOverloadKind(OverloadTy); 480 481 // Check if the operation supports overload types and OverloadTy is valid 482 // per the specified types for the operation 483 if ((ValidTyMask != OverloadKind::UNDEFINED) && 484 (ValidTyMask & (uint16_t)Kind) == 0) 485 return makeOpError(OpCode, "Invalid overload type"); 486 487 // Perform necessary checks to ensure Opcode is valid in the targeted shader 488 // kind 489 std::optional<size_t> StIndexOrErr = 490 getPropIndex(ArrayRef(Prop->Stages), DXILVersion); 491 if (!StIndexOrErr.has_value()) 492 return makeOpError(OpCode, Twine("No valid stage for DXIL version ") + 493 DXILVersion.getAsString()); 494 495 uint16_t ValidShaderKindMask = Prop->Stages[*StIndexOrErr].ValidStages; 496 497 // Ensure valid shader stage properties are specified 498 if (ValidShaderKindMask == ShaderKind::removed) 499 return makeOpError(OpCode, "Operation has been removed"); 500 501 // Shader stage need not be validated since getShaderKindEnum() fails 502 // for unknown shader stage. 503 504 // Verify the target shader stage is valid for the DXIL operation 505 ShaderKind ModuleStagekind = getShaderKindEnum(ShaderStage); 506 if (!(ValidShaderKindMask & ModuleStagekind)) 507 return makeOpError(OpCode, "Invalid stage"); 508 509 std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop); 510 FunctionCallee DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT); 511 512 // We need to inject the opcode as the first argument. 513 SmallVector<Value *> OpArgs; 514 OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode))); 515 OpArgs.append(Args.begin(), Args.end()); 516 517 // Create the function call instruction 518 CallInst *CI = IRB.CreateCall(DXILFn, OpArgs, Name); 519 520 // We then need to attach available function attributes 521 setDXILAttributes(CI, OpCode, DXILVersion); 522 523 return CI; 524 } 525 526 CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args, 527 const Twine &Name, Type *RetTy) { 528 Expected<CallInst *> Result = tryCreateOp(OpCode, Args, Name, RetTy); 529 if (Error E = Result.takeError()) 530 llvm_unreachable("Invalid arguments for operation"); 531 return *Result; 532 } 533 534 StructType *DXILOpBuilder::getResRetType(Type *ElementTy) { 535 return ::getResRetType(ElementTy); 536 } 537 538 StructType *DXILOpBuilder::getSplitDoubleType(LLVMContext &Context) { 539 return ::getSplitDoubleType(Context); 540 } 541 542 StructType *DXILOpBuilder::getHandleType() { 543 return ::getHandleType(IRB.getContext()); 544 } 545 546 Constant *DXILOpBuilder::getResBind(uint32_t LowerBound, uint32_t UpperBound, 547 uint32_t SpaceID, dxil::ResourceClass RC) { 548 Type *Int32Ty = IRB.getInt32Ty(); 549 Type *Int8Ty = IRB.getInt8Ty(); 550 return ConstantStruct::get( 551 getResBindType(IRB.getContext()), 552 {ConstantInt::get(Int32Ty, LowerBound), 553 ConstantInt::get(Int32Ty, UpperBound), 554 ConstantInt::get(Int32Ty, SpaceID), 555 ConstantInt::get(Int8Ty, llvm::to_underlying(RC))}); 556 } 557 558 Constant *DXILOpBuilder::getResProps(uint32_t Word0, uint32_t Word1) { 559 Type *Int32Ty = IRB.getInt32Ty(); 560 return ConstantStruct::get( 561 getResPropsType(IRB.getContext()), 562 {ConstantInt::get(Int32Ty, Word0), ConstantInt::get(Int32Ty, Word1)}); 563 } 564 565 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { 566 return ::getOpCodeName(DXILOp); 567 } 568 } // namespace dxil 569 } // namespace llvm 570