xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/DirectX/DXILOpBuilder.cpp (revision 972a253a57b6f144b0e4a3e2080a2a0076ec55a0)
//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2*972a253aSDimitry Andric //
3*972a253aSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*972a253aSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*972a253aSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*972a253aSDimitry Andric //
7*972a253aSDimitry Andric //===----------------------------------------------------------------------===//
8*972a253aSDimitry Andric ///
/// \file This file contains a helper class for building DXIL op functions.
10*972a253aSDimitry Andric //===----------------------------------------------------------------------===//
11*972a253aSDimitry Andric 
12*972a253aSDimitry Andric #include "DXILOpBuilder.h"
13*972a253aSDimitry Andric #include "DXILConstants.h"
14*972a253aSDimitry Andric #include "llvm/IR/IRBuilder.h"
15*972a253aSDimitry Andric #include "llvm/IR/Module.h"
16*972a253aSDimitry Andric #include "llvm/Support/DXILOperationCommon.h"
17*972a253aSDimitry Andric #include "llvm/Support/ErrorHandling.h"
18*972a253aSDimitry Andric 
19*972a253aSDimitry Andric using namespace llvm;
20*972a253aSDimitry Andric using namespace llvm::DXIL;
21*972a253aSDimitry Andric 
22*972a253aSDimitry Andric constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23*972a253aSDimitry Andric 
namespace {

/// Bit flags naming the overload types a DXIL operation can be specialised on.
/// Every kind occupies its own bit so that a set of permitted kinds fits in a
/// single uint16_t mask (OpCodeProperty::OverloadTys).
enum OverloadKind : uint16_t {
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200, // getOverloadKind() maps pointer types here.
  ObjectType = 0x0400,     // getOverloadKind() maps struct types here.
};

} // namespace
41*972a253aSDimitry Andric 
42*972a253aSDimitry Andric static const char *getOverloadTypeName(OverloadKind Kind) {
43*972a253aSDimitry Andric   switch (Kind) {
44*972a253aSDimitry Andric   case OverloadKind::HALF:
45*972a253aSDimitry Andric     return "f16";
46*972a253aSDimitry Andric   case OverloadKind::FLOAT:
47*972a253aSDimitry Andric     return "f32";
48*972a253aSDimitry Andric   case OverloadKind::DOUBLE:
49*972a253aSDimitry Andric     return "f64";
50*972a253aSDimitry Andric   case OverloadKind::I1:
51*972a253aSDimitry Andric     return "i1";
52*972a253aSDimitry Andric   case OverloadKind::I8:
53*972a253aSDimitry Andric     return "i8";
54*972a253aSDimitry Andric   case OverloadKind::I16:
55*972a253aSDimitry Andric     return "i16";
56*972a253aSDimitry Andric   case OverloadKind::I32:
57*972a253aSDimitry Andric     return "i32";
58*972a253aSDimitry Andric   case OverloadKind::I64:
59*972a253aSDimitry Andric     return "i64";
60*972a253aSDimitry Andric   case OverloadKind::VOID:
61*972a253aSDimitry Andric   case OverloadKind::ObjectType:
62*972a253aSDimitry Andric   case OverloadKind::UserDefineType:
63*972a253aSDimitry Andric     break;
64*972a253aSDimitry Andric   }
65*972a253aSDimitry Andric   llvm_unreachable("invalid overload type for name");
66*972a253aSDimitry Andric   return "void";
67*972a253aSDimitry Andric }
68*972a253aSDimitry Andric 
69*972a253aSDimitry Andric static OverloadKind getOverloadKind(Type *Ty) {
70*972a253aSDimitry Andric   Type::TypeID T = Ty->getTypeID();
71*972a253aSDimitry Andric   switch (T) {
72*972a253aSDimitry Andric   case Type::VoidTyID:
73*972a253aSDimitry Andric     return OverloadKind::VOID;
74*972a253aSDimitry Andric   case Type::HalfTyID:
75*972a253aSDimitry Andric     return OverloadKind::HALF;
76*972a253aSDimitry Andric   case Type::FloatTyID:
77*972a253aSDimitry Andric     return OverloadKind::FLOAT;
78*972a253aSDimitry Andric   case Type::DoubleTyID:
79*972a253aSDimitry Andric     return OverloadKind::DOUBLE;
80*972a253aSDimitry Andric   case Type::IntegerTyID: {
81*972a253aSDimitry Andric     IntegerType *ITy = cast<IntegerType>(Ty);
82*972a253aSDimitry Andric     unsigned Bits = ITy->getBitWidth();
83*972a253aSDimitry Andric     switch (Bits) {
84*972a253aSDimitry Andric     case 1:
85*972a253aSDimitry Andric       return OverloadKind::I1;
86*972a253aSDimitry Andric     case 8:
87*972a253aSDimitry Andric       return OverloadKind::I8;
88*972a253aSDimitry Andric     case 16:
89*972a253aSDimitry Andric       return OverloadKind::I16;
90*972a253aSDimitry Andric     case 32:
91*972a253aSDimitry Andric       return OverloadKind::I32;
92*972a253aSDimitry Andric     case 64:
93*972a253aSDimitry Andric       return OverloadKind::I64;
94*972a253aSDimitry Andric     default:
95*972a253aSDimitry Andric       llvm_unreachable("invalid overload type");
96*972a253aSDimitry Andric       return OverloadKind::VOID;
97*972a253aSDimitry Andric     }
98*972a253aSDimitry Andric   }
99*972a253aSDimitry Andric   case Type::PointerTyID:
100*972a253aSDimitry Andric     return OverloadKind::UserDefineType;
101*972a253aSDimitry Andric   case Type::StructTyID:
102*972a253aSDimitry Andric     return OverloadKind::ObjectType;
103*972a253aSDimitry Andric   default:
104*972a253aSDimitry Andric     llvm_unreachable("invalid overload type");
105*972a253aSDimitry Andric     return OverloadKind::VOID;
106*972a253aSDimitry Andric   }
107*972a253aSDimitry Andric }
108*972a253aSDimitry Andric 
109*972a253aSDimitry Andric static std::string getTypeName(OverloadKind Kind, Type *Ty) {
110*972a253aSDimitry Andric   if (Kind < OverloadKind::UserDefineType) {
111*972a253aSDimitry Andric     return getOverloadTypeName(Kind);
112*972a253aSDimitry Andric   } else if (Kind == OverloadKind::UserDefineType) {
113*972a253aSDimitry Andric     StructType *ST = cast<StructType>(Ty);
114*972a253aSDimitry Andric     return ST->getStructName().str();
115*972a253aSDimitry Andric   } else if (Kind == OverloadKind::ObjectType) {
116*972a253aSDimitry Andric     StructType *ST = cast<StructType>(Ty);
117*972a253aSDimitry Andric     return ST->getStructName().str();
118*972a253aSDimitry Andric   } else {
119*972a253aSDimitry Andric     std::string Str;
120*972a253aSDimitry Andric     raw_string_ostream OS(Str);
121*972a253aSDimitry Andric     Ty->print(OS);
122*972a253aSDimitry Andric     return OS.str();
123*972a253aSDimitry Andric   }
124*972a253aSDimitry Andric }
125*972a253aSDimitry Andric 
126*972a253aSDimitry Andric // Static properties.
127*972a253aSDimitry Andric struct OpCodeProperty {
128*972a253aSDimitry Andric   DXIL::OpCode OpCode;
129*972a253aSDimitry Andric   // Offset in DXILOpCodeNameTable.
130*972a253aSDimitry Andric   unsigned OpCodeNameOffset;
131*972a253aSDimitry Andric   DXIL::OpCodeClass OpCodeClass;
132*972a253aSDimitry Andric   // Offset in DXILOpCodeClassNameTable.
133*972a253aSDimitry Andric   unsigned OpCodeClassNameOffset;
134*972a253aSDimitry Andric   uint16_t OverloadTys;
135*972a253aSDimitry Andric   llvm::Attribute::AttrKind FuncAttr;
136*972a253aSDimitry Andric   int OverloadParamIndex;        // parameter index which control the overload.
137*972a253aSDimitry Andric                                  // When < 0, should be only 1 overload type.
138*972a253aSDimitry Andric   unsigned NumOfParameters;      // Number of parameters include return value.
139*972a253aSDimitry Andric   unsigned ParameterTableOffset; // Offset in ParameterTable.
140*972a253aSDimitry Andric };
141*972a253aSDimitry Andric 
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
144*972a253aSDimitry Andric #define DXIL_OP_OPERATION_TABLE
145*972a253aSDimitry Andric #include "DXILOperation.inc"
146*972a253aSDimitry Andric #undef DXIL_OP_OPERATION_TABLE
147*972a253aSDimitry Andric 
148*972a253aSDimitry Andric static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149*972a253aSDimitry Andric                                          const OpCodeProperty &Prop) {
150*972a253aSDimitry Andric   if (Kind == OverloadKind::VOID) {
151*972a253aSDimitry Andric     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152*972a253aSDimitry Andric   }
153*972a253aSDimitry Andric   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
154*972a253aSDimitry Andric           getTypeName(Kind, Ty))
155*972a253aSDimitry Andric       .str();
156*972a253aSDimitry Andric }
157*972a253aSDimitry Andric 
158*972a253aSDimitry Andric static std::string constructOverloadTypeName(OverloadKind Kind,
159*972a253aSDimitry Andric                                              StringRef TypeName) {
160*972a253aSDimitry Andric   if (Kind == OverloadKind::VOID)
161*972a253aSDimitry Andric     return TypeName.str();
162*972a253aSDimitry Andric 
163*972a253aSDimitry Andric   assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
164*972a253aSDimitry Andric   return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
165*972a253aSDimitry Andric }
166*972a253aSDimitry Andric 
167*972a253aSDimitry Andric static StructType *getOrCreateStructType(StringRef Name,
168*972a253aSDimitry Andric                                          ArrayRef<Type *> EltTys,
169*972a253aSDimitry Andric                                          LLVMContext &Ctx) {
170*972a253aSDimitry Andric   StructType *ST = StructType::getTypeByName(Ctx, Name);
171*972a253aSDimitry Andric   if (ST)
172*972a253aSDimitry Andric     return ST;
173*972a253aSDimitry Andric 
174*972a253aSDimitry Andric   return StructType::create(Ctx, EltTys, Name);
175*972a253aSDimitry Andric }
176*972a253aSDimitry Andric 
177*972a253aSDimitry Andric static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
178*972a253aSDimitry Andric   OverloadKind Kind = getOverloadKind(OverloadTy);
179*972a253aSDimitry Andric   std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
180*972a253aSDimitry Andric   Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
181*972a253aSDimitry Andric                          Type::getInt32Ty(Ctx)};
182*972a253aSDimitry Andric   return getOrCreateStructType(TypeName, FieldTypes, Ctx);
183*972a253aSDimitry Andric }
184*972a253aSDimitry Andric 
185*972a253aSDimitry Andric static StructType *getHandleType(LLVMContext &Ctx) {
186*972a253aSDimitry Andric   return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx);
187*972a253aSDimitry Andric }
188*972a253aSDimitry Andric 
189*972a253aSDimitry Andric static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
190*972a253aSDimitry Andric   auto &Ctx = OverloadTy->getContext();
191*972a253aSDimitry Andric   switch (Kind) {
192*972a253aSDimitry Andric   case ParameterKind::VOID:
193*972a253aSDimitry Andric     return Type::getVoidTy(Ctx);
194*972a253aSDimitry Andric   case ParameterKind::HALF:
195*972a253aSDimitry Andric     return Type::getHalfTy(Ctx);
196*972a253aSDimitry Andric   case ParameterKind::FLOAT:
197*972a253aSDimitry Andric     return Type::getFloatTy(Ctx);
198*972a253aSDimitry Andric   case ParameterKind::DOUBLE:
199*972a253aSDimitry Andric     return Type::getDoubleTy(Ctx);
200*972a253aSDimitry Andric   case ParameterKind::I1:
201*972a253aSDimitry Andric     return Type::getInt1Ty(Ctx);
202*972a253aSDimitry Andric   case ParameterKind::I8:
203*972a253aSDimitry Andric     return Type::getInt8Ty(Ctx);
204*972a253aSDimitry Andric   case ParameterKind::I16:
205*972a253aSDimitry Andric     return Type::getInt16Ty(Ctx);
206*972a253aSDimitry Andric   case ParameterKind::I32:
207*972a253aSDimitry Andric     return Type::getInt32Ty(Ctx);
208*972a253aSDimitry Andric   case ParameterKind::I64:
209*972a253aSDimitry Andric     return Type::getInt64Ty(Ctx);
210*972a253aSDimitry Andric   case ParameterKind::OVERLOAD:
211*972a253aSDimitry Andric     return OverloadTy;
212*972a253aSDimitry Andric   case ParameterKind::RESOURCE_RET:
213*972a253aSDimitry Andric     return getResRetType(OverloadTy, Ctx);
214*972a253aSDimitry Andric   case ParameterKind::DXIL_HANDLE:
215*972a253aSDimitry Andric     return getHandleType(Ctx);
216*972a253aSDimitry Andric   default:
217*972a253aSDimitry Andric     break;
218*972a253aSDimitry Andric   }
219*972a253aSDimitry Andric   llvm_unreachable("Invalid parameter kind");
220*972a253aSDimitry Andric   return nullptr;
221*972a253aSDimitry Andric }
222*972a253aSDimitry Andric 
223*972a253aSDimitry Andric static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
224*972a253aSDimitry Andric                                            Type *OverloadTy) {
225*972a253aSDimitry Andric   SmallVector<Type *> ArgTys;
226*972a253aSDimitry Andric 
227*972a253aSDimitry Andric   auto ParamKinds = getOpCodeParameterKind(*Prop);
228*972a253aSDimitry Andric 
229*972a253aSDimitry Andric   for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
230*972a253aSDimitry Andric     ParameterKind Kind = ParamKinds[I];
231*972a253aSDimitry Andric     ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
232*972a253aSDimitry Andric   }
233*972a253aSDimitry Andric   return FunctionType::get(
234*972a253aSDimitry Andric       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
235*972a253aSDimitry Andric }
236*972a253aSDimitry Andric 
237*972a253aSDimitry Andric static FunctionCallee getOrCreateDXILOpFunction(DXIL::OpCode DXILOp,
238*972a253aSDimitry Andric                                                 Type *OverloadTy, Module &M) {
239*972a253aSDimitry Andric   const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
240*972a253aSDimitry Andric 
241*972a253aSDimitry Andric   OverloadKind Kind = getOverloadKind(OverloadTy);
242*972a253aSDimitry Andric   // FIXME: find the issue and report error in clang instead of check it in
243*972a253aSDimitry Andric   // backend.
244*972a253aSDimitry Andric   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
245*972a253aSDimitry Andric     llvm_unreachable("invalid overload");
246*972a253aSDimitry Andric   }
247*972a253aSDimitry Andric 
248*972a253aSDimitry Andric   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
249*972a253aSDimitry Andric   // Dependent on name to dedup.
250*972a253aSDimitry Andric   if (auto *Fn = M.getFunction(FnName))
251*972a253aSDimitry Andric     return FunctionCallee(Fn);
252*972a253aSDimitry Andric 
253*972a253aSDimitry Andric   FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
254*972a253aSDimitry Andric   return M.getOrInsertFunction(FnName, DXILOpFT);
255*972a253aSDimitry Andric }
256*972a253aSDimitry Andric 
257*972a253aSDimitry Andric namespace llvm {
258*972a253aSDimitry Andric namespace DXIL {
259*972a253aSDimitry Andric 
260*972a253aSDimitry Andric CallInst *DXILOpBuilder::createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
261*972a253aSDimitry Andric                                           llvm::iterator_range<Use *> Args) {
262*972a253aSDimitry Andric   auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
263*972a253aSDimitry Andric   SmallVector<Value *> FullArgs;
264*972a253aSDimitry Andric   FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
265*972a253aSDimitry Andric   FullArgs.append(Args.begin(), Args.end());
266*972a253aSDimitry Andric   return B.CreateCall(Fn, FullArgs);
267*972a253aSDimitry Andric }
268*972a253aSDimitry Andric 
269*972a253aSDimitry Andric Type *DXILOpBuilder::getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
270*972a253aSDimitry Andric                                    bool NoOpCodeParam) {
271*972a253aSDimitry Andric 
272*972a253aSDimitry Andric   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
273*972a253aSDimitry Andric   if (Prop->OverloadParamIndex < 0) {
274*972a253aSDimitry Andric     auto &Ctx = FT->getContext();
275*972a253aSDimitry Andric     // When only has 1 overload type, just return it.
276*972a253aSDimitry Andric     switch (Prop->OverloadTys) {
277*972a253aSDimitry Andric     case OverloadKind::VOID:
278*972a253aSDimitry Andric       return Type::getVoidTy(Ctx);
279*972a253aSDimitry Andric     case OverloadKind::HALF:
280*972a253aSDimitry Andric       return Type::getHalfTy(Ctx);
281*972a253aSDimitry Andric     case OverloadKind::FLOAT:
282*972a253aSDimitry Andric       return Type::getFloatTy(Ctx);
283*972a253aSDimitry Andric     case OverloadKind::DOUBLE:
284*972a253aSDimitry Andric       return Type::getDoubleTy(Ctx);
285*972a253aSDimitry Andric     case OverloadKind::I1:
286*972a253aSDimitry Andric       return Type::getInt1Ty(Ctx);
287*972a253aSDimitry Andric     case OverloadKind::I8:
288*972a253aSDimitry Andric       return Type::getInt8Ty(Ctx);
289*972a253aSDimitry Andric     case OverloadKind::I16:
290*972a253aSDimitry Andric       return Type::getInt16Ty(Ctx);
291*972a253aSDimitry Andric     case OverloadKind::I32:
292*972a253aSDimitry Andric       return Type::getInt32Ty(Ctx);
293*972a253aSDimitry Andric     case OverloadKind::I64:
294*972a253aSDimitry Andric       return Type::getInt64Ty(Ctx);
295*972a253aSDimitry Andric     default:
296*972a253aSDimitry Andric       llvm_unreachable("invalid overload type");
297*972a253aSDimitry Andric       return nullptr;
298*972a253aSDimitry Andric     }
299*972a253aSDimitry Andric   }
300*972a253aSDimitry Andric 
301*972a253aSDimitry Andric   // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
302*972a253aSDimitry Andric   Type *OverloadType = FT->getReturnType();
303*972a253aSDimitry Andric   if (Prop->OverloadParamIndex != 0) {
304*972a253aSDimitry Andric     // Skip Return Type and Type for DXIL opcode.
305*972a253aSDimitry Andric     const unsigned SkipedParam = NoOpCodeParam ? 2 : 1;
306*972a253aSDimitry Andric     OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam);
307*972a253aSDimitry Andric   }
308*972a253aSDimitry Andric 
309*972a253aSDimitry Andric   auto ParamKinds = getOpCodeParameterKind(*Prop);
310*972a253aSDimitry Andric   auto Kind = ParamKinds[Prop->OverloadParamIndex];
311*972a253aSDimitry Andric   // For ResRet and CBufferRet, OverloadTy is in field of StructType.
312*972a253aSDimitry Andric   if (Kind == ParameterKind::CBUFFER_RET ||
313*972a253aSDimitry Andric       Kind == ParameterKind::RESOURCE_RET) {
314*972a253aSDimitry Andric     auto *ST = cast<StructType>(OverloadType);
315*972a253aSDimitry Andric     OverloadType = ST->getElementType(0);
316*972a253aSDimitry Andric   }
317*972a253aSDimitry Andric   return OverloadType;
318*972a253aSDimitry Andric }
319*972a253aSDimitry Andric 
320*972a253aSDimitry Andric const char *DXILOpBuilder::getOpCodeName(DXIL::OpCode DXILOp) {
321*972a253aSDimitry Andric   return ::getOpCodeName(DXILOp);
322*972a253aSDimitry Andric }
323*972a253aSDimitry Andric } // namespace DXIL
324*972a253aSDimitry Andric } // namespace llvm
325