1972a253aSDimitry Andric //===- DXILOpBuilder.cpp - Helper class for build DIXLOp functions --------===// 2972a253aSDimitry Andric // 3972a253aSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4972a253aSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5972a253aSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6972a253aSDimitry Andric // 7972a253aSDimitry Andric //===----------------------------------------------------------------------===// 8972a253aSDimitry Andric /// 9972a253aSDimitry Andric /// \file This file contains class to help build DXIL op functions. 10972a253aSDimitry Andric //===----------------------------------------------------------------------===// 11972a253aSDimitry Andric 12972a253aSDimitry Andric #include "DXILOpBuilder.h" 13972a253aSDimitry Andric #include "DXILConstants.h" 14972a253aSDimitry Andric #include "llvm/IR/IRBuilder.h" 15972a253aSDimitry Andric #include "llvm/IR/Module.h" 16972a253aSDimitry Andric #include "llvm/Support/DXILOperationCommon.h" 17972a253aSDimitry Andric #include "llvm/Support/ErrorHandling.h" 18972a253aSDimitry Andric 19972a253aSDimitry Andric using namespace llvm; 20*bdd1243dSDimitry Andric using namespace llvm::dxil; 21972a253aSDimitry Andric 22972a253aSDimitry Andric constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 23972a253aSDimitry Andric 24972a253aSDimitry Andric namespace { 25972a253aSDimitry Andric 26972a253aSDimitry Andric enum OverloadKind : uint16_t { 27972a253aSDimitry Andric VOID = 1, 28972a253aSDimitry Andric HALF = 1 << 1, 29972a253aSDimitry Andric FLOAT = 1 << 2, 30972a253aSDimitry Andric DOUBLE = 1 << 3, 31972a253aSDimitry Andric I1 = 1 << 4, 32972a253aSDimitry Andric I8 = 1 << 5, 33972a253aSDimitry Andric I16 = 1 << 6, 34972a253aSDimitry Andric I32 = 1 << 7, 35972a253aSDimitry Andric I64 = 1 << 8, 36972a253aSDimitry Andric UserDefineType = 1 << 9, 37972a253aSDimitry Andric ObjectType = 1 << 10, 38972a253aSDimitry Andric }; 39972a253aSDimitry Andric 40972a253aSDimitry Andric } // namespace 41972a253aSDimitry Andric 42972a253aSDimitry Andric static const char *getOverloadTypeName(OverloadKind Kind) { 43972a253aSDimitry Andric switch (Kind) { 44972a253aSDimitry Andric case OverloadKind::HALF: 45972a253aSDimitry Andric return "f16"; 46972a253aSDimitry Andric case OverloadKind::FLOAT: 47972a253aSDimitry Andric return "f32"; 48972a253aSDimitry Andric case OverloadKind::DOUBLE: 49972a253aSDimitry Andric return "f64"; 50972a253aSDimitry Andric case OverloadKind::I1: 51972a253aSDimitry Andric return "i1"; 52972a253aSDimitry Andric case OverloadKind::I8: 53972a253aSDimitry Andric return "i8"; 54972a253aSDimitry Andric case OverloadKind::I16: 55972a253aSDimitry Andric return "i16"; 56972a253aSDimitry Andric case OverloadKind::I32: 57972a253aSDimitry Andric return "i32"; 58972a253aSDimitry Andric case OverloadKind::I64: 59972a253aSDimitry Andric return "i64"; 60972a253aSDimitry Andric case OverloadKind::VOID: 61972a253aSDimitry Andric case OverloadKind::ObjectType: 62972a253aSDimitry Andric case OverloadKind::UserDefineType: 63972a253aSDimitry Andric break; 64972a253aSDimitry Andric } 65972a253aSDimitry Andric llvm_unreachable("invalid overload type for name"); 66972a253aSDimitry Andric return "void"; 67972a253aSDimitry Andric } 68972a253aSDimitry Andric 69972a253aSDimitry Andric static OverloadKind getOverloadKind(Type *Ty) { 70972a253aSDimitry Andric Type::TypeID T = Ty->getTypeID(); 71972a253aSDimitry Andric switch (T) { 72972a253aSDimitry Andric case Type::VoidTyID: 73972a253aSDimitry Andric return OverloadKind::VOID; 74972a253aSDimitry Andric case Type::HalfTyID: 75972a253aSDimitry Andric return OverloadKind::HALF; 76972a253aSDimitry Andric case Type::FloatTyID: 77972a253aSDimitry Andric return OverloadKind::FLOAT; 78972a253aSDimitry Andric case Type::DoubleTyID: 79972a253aSDimitry Andric return OverloadKind::DOUBLE; 80972a253aSDimitry Andric case Type::IntegerTyID: { 81972a253aSDimitry Andric IntegerType *ITy = cast<IntegerType>(Ty); 82972a253aSDimitry Andric unsigned Bits = ITy->getBitWidth(); 83972a253aSDimitry Andric switch (Bits) { 84972a253aSDimitry Andric case 1: 85972a253aSDimitry Andric return OverloadKind::I1; 86972a253aSDimitry Andric case 8: 87972a253aSDimitry Andric return OverloadKind::I8; 88972a253aSDimitry Andric case 16: 89972a253aSDimitry Andric return OverloadKind::I16; 90972a253aSDimitry Andric case 32: 91972a253aSDimitry Andric return OverloadKind::I32; 92972a253aSDimitry Andric case 64: 93972a253aSDimitry Andric return OverloadKind::I64; 94972a253aSDimitry Andric default: 95972a253aSDimitry Andric llvm_unreachable("invalid overload type"); 96972a253aSDimitry Andric return OverloadKind::VOID; 97972a253aSDimitry Andric } 98972a253aSDimitry Andric } 99972a253aSDimitry Andric case Type::PointerTyID: 100972a253aSDimitry Andric return OverloadKind::UserDefineType; 101972a253aSDimitry Andric case Type::StructTyID: 102972a253aSDimitry Andric return OverloadKind::ObjectType; 103972a253aSDimitry Andric default: 104972a253aSDimitry Andric llvm_unreachable("invalid overload type"); 105972a253aSDimitry Andric return OverloadKind::VOID; 106972a253aSDimitry Andric } 107972a253aSDimitry Andric } 108972a253aSDimitry Andric 109972a253aSDimitry Andric static std::string getTypeName(OverloadKind Kind, Type *Ty) { 110972a253aSDimitry Andric if (Kind < OverloadKind::UserDefineType) { 111972a253aSDimitry Andric return getOverloadTypeName(Kind); 112972a253aSDimitry Andric } else if (Kind == OverloadKind::UserDefineType) { 113972a253aSDimitry Andric StructType *ST = cast<StructType>(Ty); 114972a253aSDimitry Andric return ST->getStructName().str(); 115972a253aSDimitry Andric } else if (Kind == OverloadKind::ObjectType) { 116972a253aSDimitry Andric StructType *ST = cast<StructType>(Ty); 117972a253aSDimitry Andric return ST->getStructName().str(); 118972a253aSDimitry Andric } else { 119972a253aSDimitry Andric std::string Str; 120972a253aSDimitry Andric raw_string_ostream OS(Str); 121972a253aSDimitry Andric Ty->print(OS); 122972a253aSDimitry Andric return OS.str(); 123972a253aSDimitry Andric } 124972a253aSDimitry Andric } 125972a253aSDimitry Andric 126972a253aSDimitry Andric // Static properties. 127972a253aSDimitry Andric struct OpCodeProperty { 128*bdd1243dSDimitry Andric dxil::OpCode OpCode; 129972a253aSDimitry Andric // Offset in DXILOpCodeNameTable. 130972a253aSDimitry Andric unsigned OpCodeNameOffset; 131*bdd1243dSDimitry Andric dxil::OpCodeClass OpCodeClass; 132972a253aSDimitry Andric // Offset in DXILOpCodeClassNameTable. 133972a253aSDimitry Andric unsigned OpCodeClassNameOffset; 134972a253aSDimitry Andric uint16_t OverloadTys; 135972a253aSDimitry Andric llvm::Attribute::AttrKind FuncAttr; 136972a253aSDimitry Andric int OverloadParamIndex; // parameter index which control the overload. 137972a253aSDimitry Andric // When < 0, should be only 1 overload type. 138972a253aSDimitry Andric unsigned NumOfParameters; // Number of parameters include return value. 139972a253aSDimitry Andric unsigned ParameterTableOffset; // Offset in ParameterTable. 140972a253aSDimitry Andric }; 141972a253aSDimitry Andric 142972a253aSDimitry Andric // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and 143972a253aSDimitry Andric // getOpCodeParameterKind which generated by tableGen. 144972a253aSDimitry Andric #define DXIL_OP_OPERATION_TABLE 145972a253aSDimitry Andric #include "DXILOperation.inc" 146972a253aSDimitry Andric #undef DXIL_OP_OPERATION_TABLE 147972a253aSDimitry Andric 148972a253aSDimitry Andric static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 149972a253aSDimitry Andric const OpCodeProperty &Prop) { 150972a253aSDimitry Andric if (Kind == OverloadKind::VOID) { 151972a253aSDimitry Andric return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 152972a253aSDimitry Andric } 153972a253aSDimitry Andric return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 154972a253aSDimitry Andric getTypeName(Kind, Ty)) 155972a253aSDimitry Andric .str(); 156972a253aSDimitry Andric } 157972a253aSDimitry Andric 158972a253aSDimitry Andric static std::string constructOverloadTypeName(OverloadKind Kind, 159972a253aSDimitry Andric StringRef TypeName) { 160972a253aSDimitry Andric if (Kind == OverloadKind::VOID) 161972a253aSDimitry Andric return TypeName.str(); 162972a253aSDimitry Andric 163972a253aSDimitry Andric assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); 164972a253aSDimitry Andric return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); 165972a253aSDimitry Andric } 166972a253aSDimitry Andric 167972a253aSDimitry Andric static StructType *getOrCreateStructType(StringRef Name, 168972a253aSDimitry Andric ArrayRef<Type *> EltTys, 169972a253aSDimitry Andric LLVMContext &Ctx) { 170972a253aSDimitry Andric StructType *ST = StructType::getTypeByName(Ctx, Name); 171972a253aSDimitry Andric if (ST) 172972a253aSDimitry Andric return ST; 173972a253aSDimitry Andric 174972a253aSDimitry Andric return StructType::create(Ctx, EltTys, Name); 175972a253aSDimitry Andric } 176972a253aSDimitry Andric 177972a253aSDimitry Andric static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { 178972a253aSDimitry Andric OverloadKind Kind = getOverloadKind(OverloadTy); 179972a253aSDimitry Andric std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); 180972a253aSDimitry Andric Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, 181972a253aSDimitry Andric Type::getInt32Ty(Ctx)}; 182972a253aSDimitry Andric return getOrCreateStructType(TypeName, FieldTypes, Ctx); 183972a253aSDimitry Andric } 184972a253aSDimitry Andric 185972a253aSDimitry Andric static StructType *getHandleType(LLVMContext &Ctx) { 186972a253aSDimitry Andric return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx); 187972a253aSDimitry Andric } 188972a253aSDimitry Andric 189972a253aSDimitry Andric static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { 190972a253aSDimitry Andric auto &Ctx = OverloadTy->getContext(); 191972a253aSDimitry Andric switch (Kind) { 192972a253aSDimitry Andric case ParameterKind::VOID: 193972a253aSDimitry Andric return Type::getVoidTy(Ctx); 194972a253aSDimitry Andric case ParameterKind::HALF: 195972a253aSDimitry Andric return Type::getHalfTy(Ctx); 196972a253aSDimitry Andric case ParameterKind::FLOAT: 197972a253aSDimitry Andric return Type::getFloatTy(Ctx); 198972a253aSDimitry Andric case ParameterKind::DOUBLE: 199972a253aSDimitry Andric return Type::getDoubleTy(Ctx); 200972a253aSDimitry Andric case ParameterKind::I1: 201972a253aSDimitry Andric return Type::getInt1Ty(Ctx); 202972a253aSDimitry Andric case ParameterKind::I8: 203972a253aSDimitry Andric return Type::getInt8Ty(Ctx); 204972a253aSDimitry Andric case ParameterKind::I16: 205972a253aSDimitry Andric return Type::getInt16Ty(Ctx); 206972a253aSDimitry Andric case ParameterKind::I32: 207972a253aSDimitry Andric return Type::getInt32Ty(Ctx); 208972a253aSDimitry Andric case ParameterKind::I64: 209972a253aSDimitry Andric return Type::getInt64Ty(Ctx); 210972a253aSDimitry Andric case ParameterKind::OVERLOAD: 211972a253aSDimitry Andric return OverloadTy; 212972a253aSDimitry Andric case ParameterKind::RESOURCE_RET: 213972a253aSDimitry Andric return getResRetType(OverloadTy, Ctx); 214972a253aSDimitry Andric case ParameterKind::DXIL_HANDLE: 215972a253aSDimitry Andric return getHandleType(Ctx); 216972a253aSDimitry Andric default: 217972a253aSDimitry Andric break; 218972a253aSDimitry Andric } 219972a253aSDimitry Andric llvm_unreachable("Invalid parameter kind"); 220972a253aSDimitry Andric return nullptr; 221972a253aSDimitry Andric } 222972a253aSDimitry Andric 223972a253aSDimitry Andric static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, 224972a253aSDimitry Andric Type *OverloadTy) { 225972a253aSDimitry Andric SmallVector<Type *> ArgTys; 226972a253aSDimitry Andric 227972a253aSDimitry Andric auto ParamKinds = getOpCodeParameterKind(*Prop); 228972a253aSDimitry Andric 229972a253aSDimitry Andric for (unsigned I = 0; I < Prop->NumOfParameters; ++I) { 230972a253aSDimitry Andric ParameterKind Kind = ParamKinds[I]; 231972a253aSDimitry Andric ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy)); 232972a253aSDimitry Andric } 233972a253aSDimitry Andric return FunctionType::get( 234972a253aSDimitry Andric ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false); 235972a253aSDimitry Andric } 236972a253aSDimitry Andric 237*bdd1243dSDimitry Andric static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp, 238972a253aSDimitry Andric Type *OverloadTy, Module &M) { 239972a253aSDimitry Andric const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); 240972a253aSDimitry Andric 241972a253aSDimitry Andric OverloadKind Kind = getOverloadKind(OverloadTy); 242972a253aSDimitry Andric // FIXME: find the issue and report error in clang instead of check it in 243972a253aSDimitry Andric // backend. 244972a253aSDimitry Andric if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 245972a253aSDimitry Andric llvm_unreachable("invalid overload"); 246972a253aSDimitry Andric } 247972a253aSDimitry Andric 248972a253aSDimitry Andric std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); 249972a253aSDimitry Andric // Dependent on name to dedup. 250972a253aSDimitry Andric if (auto *Fn = M.getFunction(FnName)) 251972a253aSDimitry Andric return FunctionCallee(Fn); 252972a253aSDimitry Andric 253972a253aSDimitry Andric FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy); 254972a253aSDimitry Andric return M.getOrInsertFunction(FnName, DXILOpFT); 255972a253aSDimitry Andric } 256972a253aSDimitry Andric 257972a253aSDimitry Andric namespace llvm { 258*bdd1243dSDimitry Andric namespace dxil { 259972a253aSDimitry Andric 260*bdd1243dSDimitry Andric CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy, 261972a253aSDimitry Andric llvm::iterator_range<Use *> Args) { 262972a253aSDimitry Andric auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M); 263972a253aSDimitry Andric SmallVector<Value *> FullArgs; 264972a253aSDimitry Andric FullArgs.emplace_back(B.getInt32((int32_t)OpCode)); 265972a253aSDimitry Andric FullArgs.append(Args.begin(), Args.end()); 266972a253aSDimitry Andric return B.CreateCall(Fn, FullArgs); 267972a253aSDimitry Andric } 268972a253aSDimitry Andric 269*bdd1243dSDimitry Andric Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT, 270972a253aSDimitry Andric bool NoOpCodeParam) { 271972a253aSDimitry Andric 272972a253aSDimitry Andric const OpCodeProperty *Prop = getOpCodeProperty(OpCode); 273972a253aSDimitry Andric if (Prop->OverloadParamIndex < 0) { 274972a253aSDimitry Andric auto &Ctx = FT->getContext(); 275972a253aSDimitry Andric // When only has 1 overload type, just return it. 276972a253aSDimitry Andric switch (Prop->OverloadTys) { 277972a253aSDimitry Andric case OverloadKind::VOID: 278972a253aSDimitry Andric return Type::getVoidTy(Ctx); 279972a253aSDimitry Andric case OverloadKind::HALF: 280972a253aSDimitry Andric return Type::getHalfTy(Ctx); 281972a253aSDimitry Andric case OverloadKind::FLOAT: 282972a253aSDimitry Andric return Type::getFloatTy(Ctx); 283972a253aSDimitry Andric case OverloadKind::DOUBLE: 284972a253aSDimitry Andric return Type::getDoubleTy(Ctx); 285972a253aSDimitry Andric case OverloadKind::I1: 286972a253aSDimitry Andric return Type::getInt1Ty(Ctx); 287972a253aSDimitry Andric case OverloadKind::I8: 288972a253aSDimitry Andric return Type::getInt8Ty(Ctx); 289972a253aSDimitry Andric case OverloadKind::I16: 290972a253aSDimitry Andric return Type::getInt16Ty(Ctx); 291972a253aSDimitry Andric case OverloadKind::I32: 292972a253aSDimitry Andric return Type::getInt32Ty(Ctx); 293972a253aSDimitry Andric case OverloadKind::I64: 294972a253aSDimitry Andric return Type::getInt64Ty(Ctx); 295972a253aSDimitry Andric default: 296972a253aSDimitry Andric llvm_unreachable("invalid overload type"); 297972a253aSDimitry Andric return nullptr; 298972a253aSDimitry Andric } 299972a253aSDimitry Andric } 300972a253aSDimitry Andric 301972a253aSDimitry Andric // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). 302972a253aSDimitry Andric Type *OverloadType = FT->getReturnType(); 303972a253aSDimitry Andric if (Prop->OverloadParamIndex != 0) { 304972a253aSDimitry Andric // Skip Return Type and Type for DXIL opcode. 305972a253aSDimitry Andric const unsigned SkipedParam = NoOpCodeParam ? 2 : 1; 306972a253aSDimitry Andric OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam); 307972a253aSDimitry Andric } 308972a253aSDimitry Andric 309972a253aSDimitry Andric auto ParamKinds = getOpCodeParameterKind(*Prop); 310972a253aSDimitry Andric auto Kind = ParamKinds[Prop->OverloadParamIndex]; 311972a253aSDimitry Andric // For ResRet and CBufferRet, OverloadTy is in field of StructType. 312972a253aSDimitry Andric if (Kind == ParameterKind::CBUFFER_RET || 313972a253aSDimitry Andric Kind == ParameterKind::RESOURCE_RET) { 314972a253aSDimitry Andric auto *ST = cast<StructType>(OverloadType); 315972a253aSDimitry Andric OverloadType = ST->getElementType(0); 316972a253aSDimitry Andric } 317972a253aSDimitry Andric return OverloadType; 318972a253aSDimitry Andric } 319972a253aSDimitry Andric 320*bdd1243dSDimitry Andric const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { 321972a253aSDimitry Andric return ::getOpCodeName(DXILOp); 322972a253aSDimitry Andric } 323*bdd1243dSDimitry Andric } // namespace dxil 324972a253aSDimitry Andric } // namespace llvm 325