xref: /openbsd-src/gnu/llvm/llvm/lib/Target/DirectX/DXILOpBuilder.cpp (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2*d415bd75Srobert //
3*d415bd75Srobert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*d415bd75Srobert // See https://llvm.org/LICENSE.txt for license information.
5*d415bd75Srobert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*d415bd75Srobert //
7*d415bd75Srobert //===----------------------------------------------------------------------===//
8*d415bd75Srobert ///
/// \file This file contains a class to help build DXIL op functions.
10*d415bd75Srobert //===----------------------------------------------------------------------===//
11*d415bd75Srobert 
12*d415bd75Srobert #include "DXILOpBuilder.h"
13*d415bd75Srobert #include "DXILConstants.h"
14*d415bd75Srobert #include "llvm/IR/IRBuilder.h"
15*d415bd75Srobert #include "llvm/IR/Module.h"
16*d415bd75Srobert #include "llvm/Support/DXILOperationCommon.h"
17*d415bd75Srobert #include "llvm/Support/ErrorHandling.h"
18*d415bd75Srobert 
19*d415bd75Srobert using namespace llvm;
20*d415bd75Srobert using namespace llvm::dxil;
21*d415bd75Srobert 
22*d415bd75Srobert constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23*d415bd75Srobert 
namespace {

/// Types a DXIL operation can be overloaded on. Each kind owns a distinct bit
/// so that the set of overloads an operation accepts can be stored as a
/// uint16_t mask (see OpCodeProperty::OverloadTys).
enum OverloadKind : uint16_t {
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

} // namespace
41*d415bd75Srobert 
getOverloadTypeName(OverloadKind Kind)42*d415bd75Srobert static const char *getOverloadTypeName(OverloadKind Kind) {
43*d415bd75Srobert   switch (Kind) {
44*d415bd75Srobert   case OverloadKind::HALF:
45*d415bd75Srobert     return "f16";
46*d415bd75Srobert   case OverloadKind::FLOAT:
47*d415bd75Srobert     return "f32";
48*d415bd75Srobert   case OverloadKind::DOUBLE:
49*d415bd75Srobert     return "f64";
50*d415bd75Srobert   case OverloadKind::I1:
51*d415bd75Srobert     return "i1";
52*d415bd75Srobert   case OverloadKind::I8:
53*d415bd75Srobert     return "i8";
54*d415bd75Srobert   case OverloadKind::I16:
55*d415bd75Srobert     return "i16";
56*d415bd75Srobert   case OverloadKind::I32:
57*d415bd75Srobert     return "i32";
58*d415bd75Srobert   case OverloadKind::I64:
59*d415bd75Srobert     return "i64";
60*d415bd75Srobert   case OverloadKind::VOID:
61*d415bd75Srobert   case OverloadKind::ObjectType:
62*d415bd75Srobert   case OverloadKind::UserDefineType:
63*d415bd75Srobert     break;
64*d415bd75Srobert   }
65*d415bd75Srobert   llvm_unreachable("invalid overload type for name");
66*d415bd75Srobert   return "void";
67*d415bd75Srobert }
68*d415bd75Srobert 
getOverloadKind(Type * Ty)69*d415bd75Srobert static OverloadKind getOverloadKind(Type *Ty) {
70*d415bd75Srobert   Type::TypeID T = Ty->getTypeID();
71*d415bd75Srobert   switch (T) {
72*d415bd75Srobert   case Type::VoidTyID:
73*d415bd75Srobert     return OverloadKind::VOID;
74*d415bd75Srobert   case Type::HalfTyID:
75*d415bd75Srobert     return OverloadKind::HALF;
76*d415bd75Srobert   case Type::FloatTyID:
77*d415bd75Srobert     return OverloadKind::FLOAT;
78*d415bd75Srobert   case Type::DoubleTyID:
79*d415bd75Srobert     return OverloadKind::DOUBLE;
80*d415bd75Srobert   case Type::IntegerTyID: {
81*d415bd75Srobert     IntegerType *ITy = cast<IntegerType>(Ty);
82*d415bd75Srobert     unsigned Bits = ITy->getBitWidth();
83*d415bd75Srobert     switch (Bits) {
84*d415bd75Srobert     case 1:
85*d415bd75Srobert       return OverloadKind::I1;
86*d415bd75Srobert     case 8:
87*d415bd75Srobert       return OverloadKind::I8;
88*d415bd75Srobert     case 16:
89*d415bd75Srobert       return OverloadKind::I16;
90*d415bd75Srobert     case 32:
91*d415bd75Srobert       return OverloadKind::I32;
92*d415bd75Srobert     case 64:
93*d415bd75Srobert       return OverloadKind::I64;
94*d415bd75Srobert     default:
95*d415bd75Srobert       llvm_unreachable("invalid overload type");
96*d415bd75Srobert       return OverloadKind::VOID;
97*d415bd75Srobert     }
98*d415bd75Srobert   }
99*d415bd75Srobert   case Type::PointerTyID:
100*d415bd75Srobert     return OverloadKind::UserDefineType;
101*d415bd75Srobert   case Type::StructTyID:
102*d415bd75Srobert     return OverloadKind::ObjectType;
103*d415bd75Srobert   default:
104*d415bd75Srobert     llvm_unreachable("invalid overload type");
105*d415bd75Srobert     return OverloadKind::VOID;
106*d415bd75Srobert   }
107*d415bd75Srobert }
108*d415bd75Srobert 
getTypeName(OverloadKind Kind,Type * Ty)109*d415bd75Srobert static std::string getTypeName(OverloadKind Kind, Type *Ty) {
110*d415bd75Srobert   if (Kind < OverloadKind::UserDefineType) {
111*d415bd75Srobert     return getOverloadTypeName(Kind);
112*d415bd75Srobert   } else if (Kind == OverloadKind::UserDefineType) {
113*d415bd75Srobert     StructType *ST = cast<StructType>(Ty);
114*d415bd75Srobert     return ST->getStructName().str();
115*d415bd75Srobert   } else if (Kind == OverloadKind::ObjectType) {
116*d415bd75Srobert     StructType *ST = cast<StructType>(Ty);
117*d415bd75Srobert     return ST->getStructName().str();
118*d415bd75Srobert   } else {
119*d415bd75Srobert     std::string Str;
120*d415bd75Srobert     raw_string_ostream OS(Str);
121*d415bd75Srobert     Ty->print(OS);
122*d415bd75Srobert     return OS.str();
123*d415bd75Srobert   }
124*d415bd75Srobert }
125*d415bd75Srobert 
126*d415bd75Srobert // Static properties.
127*d415bd75Srobert struct OpCodeProperty {
128*d415bd75Srobert   dxil::OpCode OpCode;
129*d415bd75Srobert   // Offset in DXILOpCodeNameTable.
130*d415bd75Srobert   unsigned OpCodeNameOffset;
131*d415bd75Srobert   dxil::OpCodeClass OpCodeClass;
132*d415bd75Srobert   // Offset in DXILOpCodeClassNameTable.
133*d415bd75Srobert   unsigned OpCodeClassNameOffset;
134*d415bd75Srobert   uint16_t OverloadTys;
135*d415bd75Srobert   llvm::Attribute::AttrKind FuncAttr;
136*d415bd75Srobert   int OverloadParamIndex;        // parameter index which control the overload.
137*d415bd75Srobert                                  // When < 0, should be only 1 overload type.
138*d415bd75Srobert   unsigned NumOfParameters;      // Number of parameters include return value.
139*d415bd75Srobert   unsigned ParameterTableOffset; // Offset in ParameterTable.
140*d415bd75Srobert };
141*d415bd75Srobert 
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
144*d415bd75Srobert #define DXIL_OP_OPERATION_TABLE
145*d415bd75Srobert #include "DXILOperation.inc"
146*d415bd75Srobert #undef DXIL_OP_OPERATION_TABLE
147*d415bd75Srobert 
constructOverloadName(OverloadKind Kind,Type * Ty,const OpCodeProperty & Prop)148*d415bd75Srobert static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149*d415bd75Srobert                                          const OpCodeProperty &Prop) {
150*d415bd75Srobert   if (Kind == OverloadKind::VOID) {
151*d415bd75Srobert     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152*d415bd75Srobert   }
153*d415bd75Srobert   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
154*d415bd75Srobert           getTypeName(Kind, Ty))
155*d415bd75Srobert       .str();
156*d415bd75Srobert }
157*d415bd75Srobert 
constructOverloadTypeName(OverloadKind Kind,StringRef TypeName)158*d415bd75Srobert static std::string constructOverloadTypeName(OverloadKind Kind,
159*d415bd75Srobert                                              StringRef TypeName) {
160*d415bd75Srobert   if (Kind == OverloadKind::VOID)
161*d415bd75Srobert     return TypeName.str();
162*d415bd75Srobert 
163*d415bd75Srobert   assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
164*d415bd75Srobert   return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
165*d415bd75Srobert }
166*d415bd75Srobert 
getOrCreateStructType(StringRef Name,ArrayRef<Type * > EltTys,LLVMContext & Ctx)167*d415bd75Srobert static StructType *getOrCreateStructType(StringRef Name,
168*d415bd75Srobert                                          ArrayRef<Type *> EltTys,
169*d415bd75Srobert                                          LLVMContext &Ctx) {
170*d415bd75Srobert   StructType *ST = StructType::getTypeByName(Ctx, Name);
171*d415bd75Srobert   if (ST)
172*d415bd75Srobert     return ST;
173*d415bd75Srobert 
174*d415bd75Srobert   return StructType::create(Ctx, EltTys, Name);
175*d415bd75Srobert }
176*d415bd75Srobert 
getResRetType(Type * OverloadTy,LLVMContext & Ctx)177*d415bd75Srobert static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
178*d415bd75Srobert   OverloadKind Kind = getOverloadKind(OverloadTy);
179*d415bd75Srobert   std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
180*d415bd75Srobert   Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
181*d415bd75Srobert                          Type::getInt32Ty(Ctx)};
182*d415bd75Srobert   return getOrCreateStructType(TypeName, FieldTypes, Ctx);
183*d415bd75Srobert }
184*d415bd75Srobert 
getHandleType(LLVMContext & Ctx)185*d415bd75Srobert static StructType *getHandleType(LLVMContext &Ctx) {
186*d415bd75Srobert   return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx);
187*d415bd75Srobert }
188*d415bd75Srobert 
getTypeFromParameterKind(ParameterKind Kind,Type * OverloadTy)189*d415bd75Srobert static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
190*d415bd75Srobert   auto &Ctx = OverloadTy->getContext();
191*d415bd75Srobert   switch (Kind) {
192*d415bd75Srobert   case ParameterKind::VOID:
193*d415bd75Srobert     return Type::getVoidTy(Ctx);
194*d415bd75Srobert   case ParameterKind::HALF:
195*d415bd75Srobert     return Type::getHalfTy(Ctx);
196*d415bd75Srobert   case ParameterKind::FLOAT:
197*d415bd75Srobert     return Type::getFloatTy(Ctx);
198*d415bd75Srobert   case ParameterKind::DOUBLE:
199*d415bd75Srobert     return Type::getDoubleTy(Ctx);
200*d415bd75Srobert   case ParameterKind::I1:
201*d415bd75Srobert     return Type::getInt1Ty(Ctx);
202*d415bd75Srobert   case ParameterKind::I8:
203*d415bd75Srobert     return Type::getInt8Ty(Ctx);
204*d415bd75Srobert   case ParameterKind::I16:
205*d415bd75Srobert     return Type::getInt16Ty(Ctx);
206*d415bd75Srobert   case ParameterKind::I32:
207*d415bd75Srobert     return Type::getInt32Ty(Ctx);
208*d415bd75Srobert   case ParameterKind::I64:
209*d415bd75Srobert     return Type::getInt64Ty(Ctx);
210*d415bd75Srobert   case ParameterKind::OVERLOAD:
211*d415bd75Srobert     return OverloadTy;
212*d415bd75Srobert   case ParameterKind::RESOURCE_RET:
213*d415bd75Srobert     return getResRetType(OverloadTy, Ctx);
214*d415bd75Srobert   case ParameterKind::DXIL_HANDLE:
215*d415bd75Srobert     return getHandleType(Ctx);
216*d415bd75Srobert   default:
217*d415bd75Srobert     break;
218*d415bd75Srobert   }
219*d415bd75Srobert   llvm_unreachable("Invalid parameter kind");
220*d415bd75Srobert   return nullptr;
221*d415bd75Srobert }
222*d415bd75Srobert 
getDXILOpFunctionType(const OpCodeProperty * Prop,Type * OverloadTy)223*d415bd75Srobert static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
224*d415bd75Srobert                                            Type *OverloadTy) {
225*d415bd75Srobert   SmallVector<Type *> ArgTys;
226*d415bd75Srobert 
227*d415bd75Srobert   auto ParamKinds = getOpCodeParameterKind(*Prop);
228*d415bd75Srobert 
229*d415bd75Srobert   for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
230*d415bd75Srobert     ParameterKind Kind = ParamKinds[I];
231*d415bd75Srobert     ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
232*d415bd75Srobert   }
233*d415bd75Srobert   return FunctionType::get(
234*d415bd75Srobert       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
235*d415bd75Srobert }
236*d415bd75Srobert 
getOrCreateDXILOpFunction(dxil::OpCode DXILOp,Type * OverloadTy,Module & M)237*d415bd75Srobert static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
238*d415bd75Srobert                                                 Type *OverloadTy, Module &M) {
239*d415bd75Srobert   const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
240*d415bd75Srobert 
241*d415bd75Srobert   OverloadKind Kind = getOverloadKind(OverloadTy);
242*d415bd75Srobert   // FIXME: find the issue and report error in clang instead of check it in
243*d415bd75Srobert   // backend.
244*d415bd75Srobert   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
245*d415bd75Srobert     llvm_unreachable("invalid overload");
246*d415bd75Srobert   }
247*d415bd75Srobert 
248*d415bd75Srobert   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
249*d415bd75Srobert   // Dependent on name to dedup.
250*d415bd75Srobert   if (auto *Fn = M.getFunction(FnName))
251*d415bd75Srobert     return FunctionCallee(Fn);
252*d415bd75Srobert 
253*d415bd75Srobert   FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
254*d415bd75Srobert   return M.getOrInsertFunction(FnName, DXILOpFT);
255*d415bd75Srobert }
256*d415bd75Srobert 
257*d415bd75Srobert namespace llvm {
258*d415bd75Srobert namespace dxil {
259*d415bd75Srobert 
createDXILOpCall(dxil::OpCode OpCode,Type * OverloadTy,llvm::iterator_range<Use * > Args)260*d415bd75Srobert CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
261*d415bd75Srobert                                           llvm::iterator_range<Use *> Args) {
262*d415bd75Srobert   auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
263*d415bd75Srobert   SmallVector<Value *> FullArgs;
264*d415bd75Srobert   FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
265*d415bd75Srobert   FullArgs.append(Args.begin(), Args.end());
266*d415bd75Srobert   return B.CreateCall(Fn, FullArgs);
267*d415bd75Srobert }
268*d415bd75Srobert 
getOverloadTy(dxil::OpCode OpCode,FunctionType * FT,bool NoOpCodeParam)269*d415bd75Srobert Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT,
270*d415bd75Srobert                                    bool NoOpCodeParam) {
271*d415bd75Srobert 
272*d415bd75Srobert   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
273*d415bd75Srobert   if (Prop->OverloadParamIndex < 0) {
274*d415bd75Srobert     auto &Ctx = FT->getContext();
275*d415bd75Srobert     // When only has 1 overload type, just return it.
276*d415bd75Srobert     switch (Prop->OverloadTys) {
277*d415bd75Srobert     case OverloadKind::VOID:
278*d415bd75Srobert       return Type::getVoidTy(Ctx);
279*d415bd75Srobert     case OverloadKind::HALF:
280*d415bd75Srobert       return Type::getHalfTy(Ctx);
281*d415bd75Srobert     case OverloadKind::FLOAT:
282*d415bd75Srobert       return Type::getFloatTy(Ctx);
283*d415bd75Srobert     case OverloadKind::DOUBLE:
284*d415bd75Srobert       return Type::getDoubleTy(Ctx);
285*d415bd75Srobert     case OverloadKind::I1:
286*d415bd75Srobert       return Type::getInt1Ty(Ctx);
287*d415bd75Srobert     case OverloadKind::I8:
288*d415bd75Srobert       return Type::getInt8Ty(Ctx);
289*d415bd75Srobert     case OverloadKind::I16:
290*d415bd75Srobert       return Type::getInt16Ty(Ctx);
291*d415bd75Srobert     case OverloadKind::I32:
292*d415bd75Srobert       return Type::getInt32Ty(Ctx);
293*d415bd75Srobert     case OverloadKind::I64:
294*d415bd75Srobert       return Type::getInt64Ty(Ctx);
295*d415bd75Srobert     default:
296*d415bd75Srobert       llvm_unreachable("invalid overload type");
297*d415bd75Srobert       return nullptr;
298*d415bd75Srobert     }
299*d415bd75Srobert   }
300*d415bd75Srobert 
301*d415bd75Srobert   // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
302*d415bd75Srobert   Type *OverloadType = FT->getReturnType();
303*d415bd75Srobert   if (Prop->OverloadParamIndex != 0) {
304*d415bd75Srobert     // Skip Return Type and Type for DXIL opcode.
305*d415bd75Srobert     const unsigned SkipedParam = NoOpCodeParam ? 2 : 1;
306*d415bd75Srobert     OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam);
307*d415bd75Srobert   }
308*d415bd75Srobert 
309*d415bd75Srobert   auto ParamKinds = getOpCodeParameterKind(*Prop);
310*d415bd75Srobert   auto Kind = ParamKinds[Prop->OverloadParamIndex];
311*d415bd75Srobert   // For ResRet and CBufferRet, OverloadTy is in field of StructType.
312*d415bd75Srobert   if (Kind == ParameterKind::CBUFFER_RET ||
313*d415bd75Srobert       Kind == ParameterKind::RESOURCE_RET) {
314*d415bd75Srobert     auto *ST = cast<StructType>(OverloadType);
315*d415bd75Srobert     OverloadType = ST->getElementType(0);
316*d415bd75Srobert   }
317*d415bd75Srobert   return OverloadType;
318*d415bd75Srobert }
319*d415bd75Srobert 
getOpCodeName(dxil::OpCode DXILOp)320*d415bd75Srobert const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
321*d415bd75Srobert   return ::getOpCodeName(DXILOp);
322*d415bd75Srobert }
323*d415bd75Srobert } // namespace dxil
324*d415bd75Srobert } // namespace llvm
325