1 //===-- CompilationInterfaces.h - GPU compilation interfaces ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines interfaces for GPU compilation. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H 14 #define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H 15 16 #include "mlir/IR/Attributes.h" 17 #include "llvm/IR/Module.h" 18 19 namespace llvm { 20 class IRBuilderBase; 21 } 22 23 namespace mlir { 24 class SymbolTable; 25 namespace LLVM { 26 class ModuleTranslation; 27 } 28 namespace gpu { 29 enum class CompilationTarget : uint32_t; 30 constexpr StringLiteral elfSectionName = "section"; 31 32 /// This class indicates that the attribute associated with this trait is a GPU 33 /// offloading translation attribute. These kinds of attributes must implement 34 /// an interface for handling the translation of GPU offloading operations like 35 /// `gpu.binary` & `gpu.launch_func`. 36 template <typename ConcreteType> 37 class OffloadingTranslationAttrTrait 38 : public AttributeTrait::TraitBase<ConcreteType, 39 OffloadingTranslationAttrTrait> { 40 // TODO: Verify the attribute promises or implements the interface. 41 }; 42 43 /// This class serves as an opaque interface for passing options to the 44 /// `TargetAttrInterface` methods. Users of this class must implement the 45 /// `classof` method as well as using the macros `MLIR_*_EXPLICIT_TYPE_ID` to 46 /// ensure type safeness. Targets are free to ignore these options. 47 class TargetOptions { 48 public: 49 /// Constructor initializing the toolkit path, the list of files to link to, 50 /// extra command line options, the compilation target and a callback for 51 /// obtaining the parent symbol table. The default compilation target is 52 /// `Fatbin`. 53 TargetOptions( 54 StringRef toolkitPath = {}, ArrayRef<Attribute> librariesToLink = {}, 55 StringRef cmdOptions = {}, StringRef elfSection = {}, 56 CompilationTarget compilationTarget = getDefaultCompilationTarget(), 57 function_ref<SymbolTable *()> getSymbolTableCallback = {}, 58 function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, 59 function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, 60 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, 61 function_ref<void(StringRef)> isaCallback = {}); 62 63 /// Returns the typeID. 64 TypeID getTypeID() const; 65 66 /// Returns the toolkit path. 67 StringRef getToolkitPath() const; 68 69 /// Returns the LLVM libraries to link to. 70 ArrayRef<Attribute> getLibrariesToLink() const; 71 72 /// Returns the command line options. 73 StringRef getCmdOptions() const; 74 75 /// Returns the ELF section. 76 StringRef getELFSection() const; 77 78 /// Returns a tokenization of the command line options. 79 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> 80 tokenizeCmdOptions() const; 81 82 /// Returns the compilation target. 83 CompilationTarget getCompilationTarget() const; 84 85 /// Returns the result of the `getSymbolTableCallback` callback or a nullptr 86 /// if no callback was provided. 87 /// Note: The callback itself can return nullptr. It is up to the target how 88 /// to react to getting a nullptr, e.g., emitting an error or constructing the 89 /// table. 90 SymbolTable *getSymbolTable() const; 91 92 /// Returns the callback invoked with the initial LLVM IR for the device 93 /// module. 94 function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const; 95 96 /// Returns the callback invoked with LLVM IR for the device module 97 /// after linking the device libraries. 98 function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const; 99 100 /// Returns the callback invoked with LLVM IR for the device module after 101 /// LLVM optimizations but before codegen. 102 function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const; 103 104 /// Returns the callback invoked with the target ISA for the device, 105 /// for example PTX assembly. 106 function_ref<void(StringRef)> getISACallback() const; 107 108 /// Returns the default compilation target: `CompilationTarget::Fatbin`. 109 static CompilationTarget getDefaultCompilationTarget(); 110 111 protected: 112 /// Derived classes must use this constructor to initialize `typeID` to the 113 /// appropiate value: ie. `TargetOptions(TypeID::get<DerivedClass>())`. 114 TargetOptions( 115 TypeID typeID, StringRef toolkitPath = {}, 116 ArrayRef<Attribute> librariesToLink = {}, StringRef cmdOptions = {}, 117 StringRef elfSection = {}, 118 CompilationTarget compilationTarget = getDefaultCompilationTarget(), 119 function_ref<SymbolTable *()> getSymbolTableCallback = {}, 120 function_ref<void(llvm::Module &)> initialLlvmIRCallback = {}, 121 function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {}, 122 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {}, 123 function_ref<void(StringRef)> isaCallback = {}); 124 125 /// Path to the target toolkit. 126 std::string toolkitPath; 127 128 /// List of files to link with the LLVM module. 129 SmallVector<Attribute> librariesToLink; 130 131 /// An optional set of command line options to be used by the compilation 132 /// process. 133 std::string cmdOptions; 134 135 /// ELF Section where the binary needs to be located 136 std::string elfSection; 137 138 /// Compilation process target format. 139 CompilationTarget compilationTarget; 140 141 /// Callback for obtaining the parent symbol table of all the GPU modules 142 /// being serialized. 143 function_ref<SymbolTable *()> getSymbolTableCallback; 144 145 /// Callback invoked with the initial LLVM IR for the device module. 146 function_ref<void(llvm::Module &)> initialLlvmIRCallback; 147 148 /// Callback invoked with LLVM IR for the device module after 149 /// linking the device libraries. 150 function_ref<void(llvm::Module &)> linkedLlvmIRCallback; 151 152 /// Callback invoked with LLVM IR for the device module after 153 /// LLVM optimizations but before codegen. 154 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback; 155 156 /// Callback invoked with the target ISA for the device, 157 /// for example PTX assembly. 158 function_ref<void(StringRef)> isaCallback; 159 160 private: 161 TypeID typeID; 162 }; 163 } // namespace gpu 164 } // namespace mlir 165 166 MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions) 167 168 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.h.inc" 169 170 #endif // MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H 171