xref: /llvm-project/llvm/lib/Target/DirectX/DXILOpBuilder.cpp (revision 011b618644113996e2c0a8e57db40f89d20878e3)
//===- DXILOpBuilder.cpp - Helper class to build DXILOp functions ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains a helper class for building DXIL op functions.
10 //===----------------------------------------------------------------------===//
11 
12 #include "DXILOpBuilder.h"
13 #include "DXILConstants.h"
14 #include "llvm/IR/Module.h"
15 #include "llvm/Support/DXILABI.h"
16 #include "llvm/Support/ErrorHandling.h"
17 #include <optional>
18 
19 using namespace llvm;
20 using namespace llvm::dxil;
21 
22 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23 
namespace {
/// Bit-flag proxy for the types a DXIL op can be overloaded on. Each
/// enumerator other than UNDEFINED occupies its own bit so that sets of kinds
/// can be stored in a uint16_t mask (see OpOverload::ValidTys).
enum OverloadKind : uint16_t {
  UNDEFINED = 0,
  VOID = 1,
  HALF = 1 << 1,
  FLOAT = 1 << 2,
  DOUBLE = 1 << 3,
  I1 = 1 << 4,
  I8 = 1 << 5,
  I16 = 1 << 6,
  I32 = 1 << 7,
  I64 = 1 << 8,
  UserDefineType = 1 << 9,
  ObjectType = 1 << 10,
};

/// A (major, minor) version pair; both components default to zero.
struct Version {
  unsigned Major = 0;
  unsigned Minor = 0;
};

/// Mask of OverloadKind bits that are valid starting at DXILVersion.
struct OpOverload {
  Version DXILVersion;
  uint16_t ValidTys;
};

/// Mask of shader-stage bits that are valid starting at DXILVersion.
/// Kept inside the anonymous namespace like its sibling records: it is only
/// used by this file and embeds the file-local Version type.
struct OpStage {
  Version DXILVersion;
  uint32_t ValidStages;
};
} // namespace
54 
55 static const char *getOverloadTypeName(OverloadKind Kind) {
56   switch (Kind) {
57   case OverloadKind::HALF:
58     return "f16";
59   case OverloadKind::FLOAT:
60     return "f32";
61   case OverloadKind::DOUBLE:
62     return "f64";
63   case OverloadKind::I1:
64     return "i1";
65   case OverloadKind::I8:
66     return "i8";
67   case OverloadKind::I16:
68     return "i16";
69   case OverloadKind::I32:
70     return "i32";
71   case OverloadKind::I64:
72     return "i64";
73   case OverloadKind::VOID:
74   case OverloadKind::UNDEFINED:
75     return "void";
76   case OverloadKind::ObjectType:
77   case OverloadKind::UserDefineType:
78     break;
79   }
80   llvm_unreachable("invalid overload type for name");
81 }
82 
83 static OverloadKind getOverloadKind(Type *Ty) {
84   if (!Ty)
85     return OverloadKind::VOID;
86 
87   Type::TypeID T = Ty->getTypeID();
88   switch (T) {
89   case Type::VoidTyID:
90     return OverloadKind::VOID;
91   case Type::HalfTyID:
92     return OverloadKind::HALF;
93   case Type::FloatTyID:
94     return OverloadKind::FLOAT;
95   case Type::DoubleTyID:
96     return OverloadKind::DOUBLE;
97   case Type::IntegerTyID: {
98     IntegerType *ITy = cast<IntegerType>(Ty);
99     unsigned Bits = ITy->getBitWidth();
100     switch (Bits) {
101     case 1:
102       return OverloadKind::I1;
103     case 8:
104       return OverloadKind::I8;
105     case 16:
106       return OverloadKind::I16;
107     case 32:
108       return OverloadKind::I32;
109     case 64:
110       return OverloadKind::I64;
111     default:
112       llvm_unreachable("invalid overload type");
113       return OverloadKind::VOID;
114     }
115   }
116   case Type::PointerTyID:
117     return OverloadKind::UserDefineType;
118   case Type::StructTyID: {
119     // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
120     // how we're handling overloads and remove the `OverloadKind` proxy enum.
121     StructType *ST = cast<StructType>(Ty);
122     return getOverloadKind(ST->getElementType(0));
123   }
124   default:
125     return OverloadKind::UNDEFINED;
126   }
127 }
128 
129 static std::string getTypeName(OverloadKind Kind, Type *Ty) {
130   if (Kind < OverloadKind::UserDefineType) {
131     return getOverloadTypeName(Kind);
132   } else if (Kind == OverloadKind::UserDefineType) {
133     StructType *ST = cast<StructType>(Ty);
134     return ST->getStructName().str();
135   } else if (Kind == OverloadKind::ObjectType) {
136     StructType *ST = cast<StructType>(Ty);
137     return ST->getStructName().str();
138   } else {
139     std::string Str;
140     raw_string_ostream OS(Str);
141     Ty->print(OS);
142     return OS.str();
143   }
144 }
145 
146 // Static properties.
147 struct OpCodeProperty {
148   dxil::OpCode OpCode;
149   // Offset in DXILOpCodeNameTable.
150   unsigned OpCodeNameOffset;
151   dxil::OpCodeClass OpCodeClass;
152   // Offset in DXILOpCodeClassNameTable.
153   unsigned OpCodeClassNameOffset;
154   llvm::SmallVector<OpOverload> Overloads;
155   llvm::SmallVector<OpStage> Stages;
156   int OverloadParamIndex; // parameter index which control the overload.
157                           // When < 0, should be only 1 overload type.
158 };
159 
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
162 #define DXIL_OP_OPERATION_TABLE
163 #include "DXILOperation.inc"
164 #undef DXIL_OP_OPERATION_TABLE
165 
166 static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
167                                          const OpCodeProperty &Prop) {
168   if (Kind == OverloadKind::VOID) {
169     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
170   }
171   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
172           getTypeName(Kind, Ty))
173       .str();
174 }
175 
176 static std::string constructOverloadTypeName(OverloadKind Kind,
177                                              StringRef TypeName) {
178   if (Kind == OverloadKind::VOID)
179     return TypeName.str();
180 
181   assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
182   return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
183 }
184 
185 static StructType *getOrCreateStructType(StringRef Name,
186                                          ArrayRef<Type *> EltTys,
187                                          LLVMContext &Ctx) {
188   StructType *ST = StructType::getTypeByName(Ctx, Name);
189   if (ST)
190     return ST;
191 
192   return StructType::create(Ctx, EltTys, Name);
193 }
194 
195 static StructType *getResRetType(Type *ElementTy) {
196   LLVMContext &Ctx = ElementTy->getContext();
197   OverloadKind Kind = getOverloadKind(ElementTy);
198   std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
199   Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
200                          Type::getInt32Ty(Ctx)};
201   return getOrCreateStructType(TypeName, FieldTypes, Ctx);
202 }
203 
204 static StructType *getHandleType(LLVMContext &Ctx) {
205   return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx),
206                                Ctx);
207 }
208 
209 static StructType *getResBindType(LLVMContext &Context) {
210   if (auto *ST = StructType::getTypeByName(Context, "dx.types.ResBind"))
211     return ST;
212   Type *Int32Ty = Type::getInt32Ty(Context);
213   Type *Int8Ty = Type::getInt8Ty(Context);
214   return StructType::create({Int32Ty, Int32Ty, Int32Ty, Int8Ty},
215                             "dx.types.ResBind");
216 }
217 
218 static StructType *getResPropsType(LLVMContext &Context) {
219   if (auto *ST =
220           StructType::getTypeByName(Context, "dx.types.ResourceProperties"))
221     return ST;
222   Type *Int32Ty = Type::getInt32Ty(Context);
223   return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
224 }
225 
226 static StructType *getSplitDoubleType(LLVMContext &Context) {
227   if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble"))
228     return ST;
229   Type *Int32Ty = Type::getInt32Ty(Context);
230   return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble");
231 }
232 
233 static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
234                                     Type *OverloadTy) {
235   switch (Kind) {
236   case OpParamType::VoidTy:
237     return Type::getVoidTy(Ctx);
238   case OpParamType::HalfTy:
239     return Type::getHalfTy(Ctx);
240   case OpParamType::FloatTy:
241     return Type::getFloatTy(Ctx);
242   case OpParamType::DoubleTy:
243     return Type::getDoubleTy(Ctx);
244   case OpParamType::Int1Ty:
245     return Type::getInt1Ty(Ctx);
246   case OpParamType::Int8Ty:
247     return Type::getInt8Ty(Ctx);
248   case OpParamType::Int16Ty:
249     return Type::getInt16Ty(Ctx);
250   case OpParamType::Int32Ty:
251     return Type::getInt32Ty(Ctx);
252   case OpParamType::Int64Ty:
253     return Type::getInt64Ty(Ctx);
254   case OpParamType::OverloadTy:
255     return OverloadTy;
256   case OpParamType::ResRetHalfTy:
257     return getResRetType(Type::getHalfTy(Ctx));
258   case OpParamType::ResRetFloatTy:
259     return getResRetType(Type::getFloatTy(Ctx));
260   case OpParamType::ResRetDoubleTy:
261     return getResRetType(Type::getDoubleTy(Ctx));
262   case OpParamType::ResRetInt16Ty:
263     return getResRetType(Type::getInt16Ty(Ctx));
264   case OpParamType::ResRetInt32Ty:
265     return getResRetType(Type::getInt32Ty(Ctx));
266   case OpParamType::ResRetInt64Ty:
267     return getResRetType(Type::getInt64Ty(Ctx));
268   case OpParamType::HandleTy:
269     return getHandleType(Ctx);
270   case OpParamType::ResBindTy:
271     return getResBindType(Ctx);
272   case OpParamType::ResPropsTy:
273     return getResPropsType(Ctx);
274   case OpParamType::SplitDoubleTy:
275     return getSplitDoubleType(Ctx);
276   }
277   llvm_unreachable("Invalid parameter kind");
278   return nullptr;
279 }
280 
281 static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) {
282   switch (EnvType) {
283   case Triple::Pixel:
284     return ShaderKind::pixel;
285   case Triple::Vertex:
286     return ShaderKind::vertex;
287   case Triple::Geometry:
288     return ShaderKind::geometry;
289   case Triple::Hull:
290     return ShaderKind::hull;
291   case Triple::Domain:
292     return ShaderKind::domain;
293   case Triple::Compute:
294     return ShaderKind::compute;
295   case Triple::Library:
296     return ShaderKind::library;
297   case Triple::RayGeneration:
298     return ShaderKind::raygeneration;
299   case Triple::Intersection:
300     return ShaderKind::intersection;
301   case Triple::AnyHit:
302     return ShaderKind::anyhit;
303   case Triple::ClosestHit:
304     return ShaderKind::closesthit;
305   case Triple::Miss:
306     return ShaderKind::miss;
307   case Triple::Callable:
308     return ShaderKind::callable;
309   case Triple::Mesh:
310     return ShaderKind::mesh;
311   case Triple::Amplification:
312     return ShaderKind::amplification;
313   default:
314     break;
315   }
316   llvm_unreachable(
317       "Shader Kind Not Found - Invalid DXIL Environment Specified");
318 }
319 
320 static SmallVector<Type *>
321 getArgTypesFromOpParamTypes(ArrayRef<dxil::OpParamType> Types,
322                             LLVMContext &Context, Type *OverloadTy) {
323   SmallVector<Type *> ArgTys;
324   ArgTys.emplace_back(Type::getInt32Ty(Context));
325   for (dxil::OpParamType Ty : Types)
326     ArgTys.emplace_back(getTypeFromOpParamType(Ty, Context, OverloadTy));
327   return ArgTys;
328 }
329 
330 /// Construct DXIL function type. This is the type of a function with
331 /// the following prototype
332 ///     OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
333 /// <param-types> are constructed from types in Prop.
334 static FunctionType *getDXILOpFunctionType(dxil::OpCode OpCode,
335                                            LLVMContext &Context,
336                                            Type *OverloadTy) {
337 
338   switch (OpCode) {
339 #define DXIL_OP_FUNCTION_TYPE(OpCode, RetType, ...)                            \
340   case OpCode:                                                                 \
341     return FunctionType::get(                                                  \
342         getTypeFromOpParamType(RetType, Context, OverloadTy),                  \
343         getArgTypesFromOpParamTypes({__VA_ARGS__}, Context, OverloadTy),       \
344         /*isVarArg=*/false);
345 #include "DXILOperation.inc"
346   }
347   llvm_unreachable("Invalid OpCode?");
348 }
349 
350 /// Get index of the property from PropList valid for the most recent
351 /// DXIL version not greater than DXILVer.
352 /// PropList is expected to be sorted in ascending order of DXIL version.
353 template <typename T>
354 static std::optional<size_t> getPropIndex(ArrayRef<T> PropList,
355                                           const VersionTuple DXILVer) {
356   size_t Index = PropList.size() - 1;
357   for (auto Iter = PropList.rbegin(); Iter != PropList.rend();
358        Iter++, Index--) {
359     const T &Prop = *Iter;
360     if (VersionTuple(Prop.DXILVersion.Major, Prop.DXILVersion.Minor) <=
361         DXILVer) {
362       return Index;
363     }
364   }
365   return std::nullopt;
366 }
367 
368 // Helper function to pack an OpCode and VersionTuple into a uint64_t for use
369 // in a switch statement
370 constexpr static uint64_t computeSwitchEnum(dxil::OpCode OpCode,
371                                             uint16_t VersionMajor,
372                                             uint16_t VersionMinor) {
373   uint64_t OpCodePack = (uint64_t)OpCode;
374   return (OpCodePack << 32) | (VersionMajor << 16) | VersionMinor;
375 }
376 
377 // Retreive all the set attributes for a DXIL OpCode given the targeted
378 // DXILVersion
379 static dxil::Attributes getDXILAttributes(dxil::OpCode OpCode,
380                                           VersionTuple DXILVersion) {
381   // Instantiate all versions to iterate through
382   SmallVector<Version> Versions = {
383 #define DXIL_VERSION(MAJOR, MINOR) {MAJOR, MINOR},
384 #include "DXILOperation.inc"
385   };
386 
387   dxil::Attributes Attributes;
388   for (auto Version : Versions) {
389     if (DXILVersion < VersionTuple(Version.Major, Version.Minor))
390       continue;
391 
392     // Switch through and match an OpCode with the specific version and set the
393     // corresponding flag(s) if available
394     switch (computeSwitchEnum(OpCode, Version.Major, Version.Minor)) {
395 #define DXIL_OP_ATTRIBUTES(OpCode, VersionMajor, VersionMinor, ...)            \
396   case computeSwitchEnum(OpCode, VersionMajor, VersionMinor): {                \
397     auto Other = dxil::Attributes{__VA_ARGS__};                                \
398     Attributes |= Other;                                                       \
399     break;                                                                     \
400   };
401 #include "DXILOperation.inc"
402     }
403   }
404   return Attributes;
405 }
406 
407 // Retreive the set of DXIL Attributes given the version and map them to an
408 // llvm function attribute that is set onto the instruction
409 static void setDXILAttributes(CallInst *CI, dxil::OpCode OpCode,
410                               VersionTuple DXILVersion) {
411   dxil::Attributes Attributes = getDXILAttributes(OpCode, DXILVersion);
412   if (Attributes.ReadNone)
413     CI->setDoesNotAccessMemory();
414   if (Attributes.ReadOnly)
415     CI->setOnlyReadsMemory();
416   if (Attributes.NoReturn)
417     CI->setDoesNotReturn();
418   if (Attributes.NoDuplicate)
419     CI->setCannotDuplicate();
420   return;
421 }
422 
423 namespace llvm {
424 namespace dxil {
425 
426 // No extra checks on TargetTriple need be performed to verify that the
427 // Triple is well-formed or that the target is supported since these checks
428 // would have been done at the time the module M is constructed in the earlier
429 // stages of compilation.
430 DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) {
431   Triple TT(Triple(M.getTargetTriple()));
432   DXILVersion = TT.getDXILVersion();
433   ShaderStage = TT.getEnvironment();
434   // Ensure Environment type is known
435   if (ShaderStage == Triple::UnknownEnvironment) {
436     report_fatal_error(
437         Twine(DXILVersion.getAsString()) +
438             ": Unknown Compilation Target Shader Stage specified ",
439         /*gen_crash_diag*/ false);
440   }
441 }
442 
443 static Error makeOpError(dxil::OpCode OpCode, Twine Msg) {
444   return make_error<StringError>(
445       Twine("Cannot create ") + getOpCodeName(OpCode) + " operation: " + Msg,
446       inconvertibleErrorCode());
447 }
448 
449 Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
450                                                 ArrayRef<Value *> Args,
451                                                 const Twine &Name,
452                                                 Type *RetTy) {
453   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
454 
455   Type *OverloadTy = nullptr;
456   if (Prop->OverloadParamIndex == 0) {
457     if (!RetTy)
458       return makeOpError(OpCode, "Op overloaded on unknown return type");
459     OverloadTy = RetTy;
460   } else if (Prop->OverloadParamIndex > 0) {
461     // The index counts including the return type
462     unsigned ArgIndex = Prop->OverloadParamIndex - 1;
463     if (static_cast<unsigned>(ArgIndex) >= Args.size())
464       return makeOpError(OpCode, "Wrong number of arguments");
465     OverloadTy = Args[ArgIndex]->getType();
466   }
467 
468   FunctionType *DXILOpFT =
469       getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
470 
471   std::optional<size_t> OlIndexOrErr =
472       getPropIndex(ArrayRef(Prop->Overloads), DXILVersion);
473   if (!OlIndexOrErr.has_value())
474     return makeOpError(OpCode, Twine("No valid overloads for DXIL version ") +
475                                    DXILVersion.getAsString());
476 
477   uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys;
478 
479   OverloadKind Kind = getOverloadKind(OverloadTy);
480 
481   // Check if the operation supports overload types and OverloadTy is valid
482   // per the specified types for the operation
483   if ((ValidTyMask != OverloadKind::UNDEFINED) &&
484       (ValidTyMask & (uint16_t)Kind) == 0)
485     return makeOpError(OpCode, "Invalid overload type");
486 
487   // Perform necessary checks to ensure Opcode is valid in the targeted shader
488   // kind
489   std::optional<size_t> StIndexOrErr =
490       getPropIndex(ArrayRef(Prop->Stages), DXILVersion);
491   if (!StIndexOrErr.has_value())
492     return makeOpError(OpCode, Twine("No valid stage for DXIL version ") +
493                                    DXILVersion.getAsString());
494 
495   uint16_t ValidShaderKindMask = Prop->Stages[*StIndexOrErr].ValidStages;
496 
497   // Ensure valid shader stage properties are specified
498   if (ValidShaderKindMask == ShaderKind::removed)
499     return makeOpError(OpCode, "Operation has been removed");
500 
501   // Shader stage need not be validated since getShaderKindEnum() fails
502   // for unknown shader stage.
503 
504   // Verify the target shader stage is valid for the DXIL operation
505   ShaderKind ModuleStagekind = getShaderKindEnum(ShaderStage);
506   if (!(ValidShaderKindMask & ModuleStagekind))
507     return makeOpError(OpCode, "Invalid stage");
508 
509   std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
510   FunctionCallee DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
511 
512   // We need to inject the opcode as the first argument.
513   SmallVector<Value *> OpArgs;
514   OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode)));
515   OpArgs.append(Args.begin(), Args.end());
516 
517   // Create the function call instruction
518   CallInst *CI = IRB.CreateCall(DXILFn, OpArgs, Name);
519 
520   // We then need to attach available function attributes
521   setDXILAttributes(CI, OpCode, DXILVersion);
522 
523   return CI;
524 }
525 
526 CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
527                                   const Twine &Name, Type *RetTy) {
528   Expected<CallInst *> Result = tryCreateOp(OpCode, Args, Name, RetTy);
529   if (Error E = Result.takeError())
530     llvm_unreachable("Invalid arguments for operation");
531   return *Result;
532 }
533 
534 StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
535   return ::getResRetType(ElementTy);
536 }
537 
538 StructType *DXILOpBuilder::getSplitDoubleType(LLVMContext &Context) {
539   return ::getSplitDoubleType(Context);
540 }
541 
542 StructType *DXILOpBuilder::getHandleType() {
543   return ::getHandleType(IRB.getContext());
544 }
545 
546 Constant *DXILOpBuilder::getResBind(uint32_t LowerBound, uint32_t UpperBound,
547                                     uint32_t SpaceID, dxil::ResourceClass RC) {
548   Type *Int32Ty = IRB.getInt32Ty();
549   Type *Int8Ty = IRB.getInt8Ty();
550   return ConstantStruct::get(
551       getResBindType(IRB.getContext()),
552       {ConstantInt::get(Int32Ty, LowerBound),
553        ConstantInt::get(Int32Ty, UpperBound),
554        ConstantInt::get(Int32Ty, SpaceID),
555        ConstantInt::get(Int8Ty, llvm::to_underlying(RC))});
556 }
557 
558 Constant *DXILOpBuilder::getResProps(uint32_t Word0, uint32_t Word1) {
559   Type *Int32Ty = IRB.getInt32Ty();
560   return ConstantStruct::get(
561       getResPropsType(IRB.getContext()),
562       {ConstantInt::get(Int32Ty, Word0), ConstantInt::get(Int32Ty, Word1)});
563 }
564 
565 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
566   return ::getOpCodeName(DXILOp);
567 }
568 } // namespace dxil
569 } // namespace llvm
570