1 //===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains passes and utilities to lower llvm intrinsic call 10 /// to DXILOp function call. 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILConstants.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/CodeGen/Passes.h" 17 #include "llvm/IR/IRBuilder.h" 18 #include "llvm/IR/Instruction.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/IntrinsicsDirectX.h" 21 #include "llvm/IR/Module.h" 22 #include "llvm/IR/PassManager.h" 23 #include "llvm/Pass.h" 24 #include "llvm/Support/ErrorHandling.h" 25 26 #define DEBUG_TYPE "dxil-op-lower" 27 28 using namespace llvm; 29 using namespace llvm::DXIL; 30 31 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 32 33 enum OverloadKind : uint16_t { 34 VOID = 1, 35 HALF = 1 << 1, 36 FLOAT = 1 << 2, 37 DOUBLE = 1 << 3, 38 I1 = 1 << 4, 39 I8 = 1 << 5, 40 I16 = 1 << 6, 41 I32 = 1 << 7, 42 I64 = 1 << 8, 43 UserDefineType = 1 << 9, 44 ObjectType = 1 << 10, 45 }; 46 47 static const char *getOverloadTypeName(OverloadKind Kind) { 48 switch (Kind) { 49 case OverloadKind::HALF: 50 return "f16"; 51 case OverloadKind::FLOAT: 52 return "f32"; 53 case OverloadKind::DOUBLE: 54 return "f64"; 55 case OverloadKind::I1: 56 return "i1"; 57 case OverloadKind::I8: 58 return "i8"; 59 case OverloadKind::I16: 60 return "i16"; 61 case OverloadKind::I32: 62 return "i32"; 63 case OverloadKind::I64: 64 return "i64"; 65 case OverloadKind::VOID: 66 case OverloadKind::ObjectType: 67 case OverloadKind::UserDefineType: 68 break; 69 } 70 llvm_unreachable("invalid overload type for name"); 71 return "void"; 72 } 73 74 static OverloadKind getOverloadKind(Type *Ty) { 75 Type::TypeID T = Ty->getTypeID(); 76 switch (T) { 77 case Type::VoidTyID: 78 return OverloadKind::VOID; 79 case Type::HalfTyID: 80 return OverloadKind::HALF; 81 case Type::FloatTyID: 82 return OverloadKind::FLOAT; 83 case Type::DoubleTyID: 84 return OverloadKind::DOUBLE; 85 case Type::IntegerTyID: { 86 IntegerType *ITy = cast<IntegerType>(Ty); 87 unsigned Bits = ITy->getBitWidth(); 88 switch (Bits) { 89 case 1: 90 return OverloadKind::I1; 91 case 8: 92 return OverloadKind::I8; 93 case 16: 94 return OverloadKind::I16; 95 case 32: 96 return OverloadKind::I32; 97 case 64: 98 return OverloadKind::I64; 99 default: 100 llvm_unreachable("invalid overload type"); 101 return OverloadKind::VOID; 102 } 103 } 104 case Type::PointerTyID: 105 return OverloadKind::UserDefineType; 106 case Type::StructTyID: 107 return OverloadKind::ObjectType; 108 default: 109 llvm_unreachable("invalid overload type"); 110 return OverloadKind::VOID; 111 } 112 } 113 114 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 115 if (Kind < OverloadKind::UserDefineType) { 116 return getOverloadTypeName(Kind); 117 } else if (Kind == OverloadKind::UserDefineType) { 118 StructType *ST = cast<StructType>(Ty); 119 return ST->getStructName().str(); 120 } else if (Kind == OverloadKind::ObjectType) { 121 StructType *ST = cast<StructType>(Ty); 122 return ST->getStructName().str(); 123 } else { 124 std::string Str; 125 raw_string_ostream OS(Str); 126 Ty->print(OS); 127 return OS.str(); 128 } 129 } 130 131 // Static properties. 132 struct OpCodeProperty { 133 DXIL::OpCode OpCode; 134 // Offset in DXILOpCodeNameTable. 135 unsigned OpCodeNameOffset; 136 DXIL::OpCodeClass OpCodeClass; 137 // Offset in DXILOpCodeClassNameTable. 138 unsigned OpCodeClassNameOffset; 139 uint16_t OverloadTys; 140 llvm::Attribute::AttrKind FuncAttr; 141 }; 142 143 // Include getOpCodeClassName getOpCodeProperty and getOpCodeName which 144 // generated by tableGen. 145 #define DXIL_OP_OPERATION_TABLE 146 #include "DXILOperation.inc" 147 #undef DXIL_OP_OPERATION_TABLE 148 149 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 150 const OpCodeProperty &Prop) { 151 if (Kind == OverloadKind::VOID) { 152 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 153 } 154 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 155 getTypeName(Kind, Ty)) 156 .str(); 157 } 158 159 static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F, 160 Module &M) { 161 const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); 162 163 // Get return type as overload type for DXILOp. 164 // Only simple mapping case here, so return type is good enough. 165 Type *OverloadTy = F.getReturnType(); 166 167 OverloadKind Kind = getOverloadKind(OverloadTy); 168 // FIXME: find the issue and report error in clang instead of check it in 169 // backend. 170 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 171 llvm_unreachable("invalid overload"); 172 } 173 174 std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); 175 assert(!M.getFunction(FnName) && "Function already exists"); 176 177 auto &Ctx = M.getContext(); 178 Type *OpCodeTy = Type::getInt32Ty(Ctx); 179 180 SmallVector<Type *> ArgTypes; 181 // DXIL has i32 opcode as first arg. 182 ArgTypes.emplace_back(OpCodeTy); 183 FunctionType *FT = F.getFunctionType(); 184 ArgTypes.append(FT->param_begin(), FT->param_end()); 185 FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false); 186 return M.getOrInsertFunction(FnName, DXILOpFT); 187 } 188 189 static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) { 190 auto DXILOpFn = createDXILOpFunction(DXILOp, F, M); 191 IRBuilder<> B(M.getContext()); 192 Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); 193 for (User *U : make_early_inc_range(F.users())) { 194 CallInst *CI = dyn_cast<CallInst>(U); 195 if (!CI) 196 continue; 197 198 SmallVector<Value *> Args; 199 Args.emplace_back(DXILOpArg); 200 Args.append(CI->arg_begin(), CI->arg_end()); 201 B.SetInsertPoint(CI); 202 CallInst *DXILCI = B.CreateCall(DXILOpFn, Args); 203 LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp))); 204 CI->replaceAllUsesWith(DXILCI); 205 CI->eraseFromParent(); 206 } 207 if (F.user_empty()) 208 F.eraseFromParent(); 209 } 210 211 static bool lowerIntrinsics(Module &M) { 212 bool Updated = false; 213 214 #define DXIL_OP_INTRINSIC_MAP 215 #include "DXILOperation.inc" 216 #undef DXIL_OP_INTRINSIC_MAP 217 218 for (Function &F : make_early_inc_range(M.functions())) { 219 if (!F.isDeclaration()) 220 continue; 221 Intrinsic::ID ID = F.getIntrinsicID(); 222 if (ID == Intrinsic::not_intrinsic) 223 continue; 224 auto LowerIt = LowerMap.find(ID); 225 if (LowerIt == LowerMap.end()) 226 continue; 227 lowerIntrinsic(LowerIt->second, F, M); 228 Updated = true; 229 } 230 return Updated; 231 } 232 233 namespace { 234 /// A pass that transforms external global definitions into declarations. 235 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { 236 public: 237 PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { 238 if (lowerIntrinsics(M)) 239 return PreservedAnalyses::none(); 240 return PreservedAnalyses::all(); 241 } 242 }; 243 } // namespace 244 245 namespace { 246 class DXILOpLoweringLegacy : public ModulePass { 247 public: 248 bool runOnModule(Module &M) override { return lowerIntrinsics(M); } 249 StringRef getPassName() const override { return "DXIL Op Lowering"; } 250 DXILOpLoweringLegacy() : ModulePass(ID) {} 251 252 static char ID; // Pass identification. 253 }; 254 char DXILOpLoweringLegacy::ID = 0; 255 256 } // end anonymous namespace 257 258 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", 259 false, false) 260 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, 261 false) 262 263 ModulePass *llvm::createDXILOpLoweringLegacyPass() { 264 return new DXILOpLoweringLegacy(); 265 } 266