xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/DirectX/DXILOpLowering.cpp (revision 753f127f3ace09432b2baeffd71a308760641a62)
//===- DXILOpLowering.cpp - Lowering LLVM intrinsic to DXILOp function ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains passes and utilities to lower LLVM intrinsic calls
/// to DXILOp function calls.
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILConstants.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/CodeGen/Passes.h"
17 #include "llvm/IR/IRBuilder.h"
18 #include "llvm/IR/Instruction.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/IntrinsicsDirectX.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/IR/PassManager.h"
23 #include "llvm/Pass.h"
24 #include "llvm/Support/ErrorHandling.h"
25 
26 #define DEBUG_TYPE "dxil-op-lower"
27 
28 using namespace llvm;
29 using namespace llvm::DXIL;
30 
31 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
32 
33 enum OverloadKind : uint16_t {
34   VOID = 1,
35   HALF = 1 << 1,
36   FLOAT = 1 << 2,
37   DOUBLE = 1 << 3,
38   I1 = 1 << 4,
39   I8 = 1 << 5,
40   I16 = 1 << 6,
41   I32 = 1 << 7,
42   I64 = 1 << 8,
43   UserDefineType = 1 << 9,
44   ObjectType = 1 << 10,
45 };
46 
47 static const char *getOverloadTypeName(OverloadKind Kind) {
48   switch (Kind) {
49   case OverloadKind::HALF:
50     return "f16";
51   case OverloadKind::FLOAT:
52     return "f32";
53   case OverloadKind::DOUBLE:
54     return "f64";
55   case OverloadKind::I1:
56     return "i1";
57   case OverloadKind::I8:
58     return "i8";
59   case OverloadKind::I16:
60     return "i16";
61   case OverloadKind::I32:
62     return "i32";
63   case OverloadKind::I64:
64     return "i64";
65   case OverloadKind::VOID:
66   case OverloadKind::ObjectType:
67   case OverloadKind::UserDefineType:
68     break;
69   }
70   llvm_unreachable("invalid overload type for name");
71   return "void";
72 }
73 
74 static OverloadKind getOverloadKind(Type *Ty) {
75   Type::TypeID T = Ty->getTypeID();
76   switch (T) {
77   case Type::VoidTyID:
78     return OverloadKind::VOID;
79   case Type::HalfTyID:
80     return OverloadKind::HALF;
81   case Type::FloatTyID:
82     return OverloadKind::FLOAT;
83   case Type::DoubleTyID:
84     return OverloadKind::DOUBLE;
85   case Type::IntegerTyID: {
86     IntegerType *ITy = cast<IntegerType>(Ty);
87     unsigned Bits = ITy->getBitWidth();
88     switch (Bits) {
89     case 1:
90       return OverloadKind::I1;
91     case 8:
92       return OverloadKind::I8;
93     case 16:
94       return OverloadKind::I16;
95     case 32:
96       return OverloadKind::I32;
97     case 64:
98       return OverloadKind::I64;
99     default:
100       llvm_unreachable("invalid overload type");
101       return OverloadKind::VOID;
102     }
103   }
104   case Type::PointerTyID:
105     return OverloadKind::UserDefineType;
106   case Type::StructTyID:
107     return OverloadKind::ObjectType;
108   default:
109     llvm_unreachable("invalid overload type");
110     return OverloadKind::VOID;
111   }
112 }
113 
114 static std::string getTypeName(OverloadKind Kind, Type *Ty) {
115   if (Kind < OverloadKind::UserDefineType) {
116     return getOverloadTypeName(Kind);
117   } else if (Kind == OverloadKind::UserDefineType) {
118     StructType *ST = cast<StructType>(Ty);
119     return ST->getStructName().str();
120   } else if (Kind == OverloadKind::ObjectType) {
121     StructType *ST = cast<StructType>(Ty);
122     return ST->getStructName().str();
123   } else {
124     std::string Str;
125     raw_string_ostream OS(Str);
126     Ty->print(OS);
127     return OS.str();
128   }
129 }
130 
// Static properties of one DXIL operation. getOpCodeProperty() returns a
// pointer to one of these for a given opcode.
// NOTE(review): the generated table in DXILOperation.inc presumably
// brace-initializes these positionally. Keep member order and types in sync
// with the TableGen emitter.
struct OpCodeProperty {
  // The DXIL opcode these properties describe.
  DXIL::OpCode OpCode;
  // Offset in DXILOpCodeNameTable.
  unsigned OpCodeNameOffset;
  // Op-code class of the operation.
  DXIL::OpCodeClass OpCodeClass;
  // Offset in DXILOpCodeClassNameTable.
  unsigned OpCodeClassNameOffset;
  // Bitmask of OverloadKind values this operation accepts; tested against a
  // single kind when the op function is created.
  uint16_t OverloadTys;
  // Function attribute associated with the operation. Not read in this file;
  // presumably applied to the op function elsewhere (TODO confirm).
  llvm::Attribute::AttrKind FuncAttr;
};
142 
// Include getOpCodeClassName, getOpCodeProperty and getOpCodeName, which are
// generated by TableGen.
145 #define DXIL_OP_OPERATION_TABLE
146 #include "DXILOperation.inc"
147 #undef DXIL_OP_OPERATION_TABLE
148 
149 static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
150                                          const OpCodeProperty &Prop) {
151   if (Kind == OverloadKind::VOID) {
152     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
153   }
154   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
155           getTypeName(Kind, Ty))
156       .str();
157 }
158 
159 static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
160                                            Module &M) {
161   const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
162 
163   // Get return type as overload type for DXILOp.
164   // Only simple mapping case here, so return type is good enough.
165   Type *OverloadTy = F.getReturnType();
166 
167   OverloadKind Kind = getOverloadKind(OverloadTy);
168   // FIXME: find the issue and report error in clang instead of check it in
169   // backend.
170   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
171     llvm_unreachable("invalid overload");
172   }
173 
174   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
175   assert(!M.getFunction(FnName) && "Function already exists");
176 
177   auto &Ctx = M.getContext();
178   Type *OpCodeTy = Type::getInt32Ty(Ctx);
179 
180   SmallVector<Type *> ArgTypes;
181   // DXIL has i32 opcode as first arg.
182   ArgTypes.emplace_back(OpCodeTy);
183   FunctionType *FT = F.getFunctionType();
184   ArgTypes.append(FT->param_begin(), FT->param_end());
185   FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
186   return M.getOrInsertFunction(FnName, DXILOpFT);
187 }
188 
189 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
190   auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
191   IRBuilder<> B(M.getContext());
192   Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
193   for (User *U : make_early_inc_range(F.users())) {
194     CallInst *CI = dyn_cast<CallInst>(U);
195     if (!CI)
196       continue;
197 
198     SmallVector<Value *> Args;
199     Args.emplace_back(DXILOpArg);
200     Args.append(CI->arg_begin(), CI->arg_end());
201     B.SetInsertPoint(CI);
202     CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
203     LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp)));
204     CI->replaceAllUsesWith(DXILCI);
205     CI->eraseFromParent();
206   }
207   if (F.user_empty())
208     F.eraseFromParent();
209 }
210 
211 static bool lowerIntrinsics(Module &M) {
212   bool Updated = false;
213 
214 #define DXIL_OP_INTRINSIC_MAP
215 #include "DXILOperation.inc"
216 #undef DXIL_OP_INTRINSIC_MAP
217 
218   for (Function &F : make_early_inc_range(M.functions())) {
219     if (!F.isDeclaration())
220       continue;
221     Intrinsic::ID ID = F.getIntrinsicID();
222     if (ID == Intrinsic::not_intrinsic)
223       continue;
224     auto LowerIt = LowerMap.find(ID);
225     if (LowerIt == LowerMap.end())
226       continue;
227     lowerIntrinsic(LowerIt->second, F, M);
228     Updated = true;
229   }
230   return Updated;
231 }
232 
233 namespace {
234 /// A pass that transforms external global definitions into declarations.
235 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
236 public:
237   PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
238     if (lowerIntrinsics(M))
239       return PreservedAnalyses::none();
240     return PreservedAnalyses::all();
241   }
242 };
243 } // namespace
244 
245 namespace {
246 class DXILOpLoweringLegacy : public ModulePass {
247 public:
248   bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
249   StringRef getPassName() const override { return "DXIL Op Lowering"; }
250   DXILOpLoweringLegacy() : ModulePass(ID) {}
251 
252   static char ID; // Pass identification.
253 };
254 char DXILOpLoweringLegacy::ID = 0;
255 
256 } // end anonymous namespace
257 
258 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
259                       false, false)
260 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
261                     false)
262 
263 ModulePass *llvm::createDXILOpLoweringLegacyPass() {
264   return new DXILOpLoweringLegacy();
265 }
266