xref: /llvm-project/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h (revision 72e8b9aeaa3f584f223bc59924812df69a09a48b)
1 //===-- CompilationInterfaces.h - GPU compilation interfaces  ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines interfaces for GPU compilation.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
14 #define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
15 
16 #include "mlir/IR/Attributes.h"
17 #include "llvm/IR/Module.h"
18 
19 namespace llvm {
20 class IRBuilderBase; // Forward declaration.
21 } // namespace llvm
22 
23 namespace mlir {
24 class SymbolTable; // Forward declaration; used below only as `SymbolTable *`.
25 namespace LLVM {
26 class ModuleTranslation; // Forward declaration.
27 } // namespace LLVM
28 namespace gpu {
29 enum class CompilationTarget : uint32_t; // Opaque declaration; the enumerators are not listed here.
30 constexpr StringLiteral elfSectionName = "section"; // NOTE(review): presumably a lookup key for the ELF section - confirm.
31 
32 /// Trait marking the attribute it is attached to as a GPU offloading
33 /// translation attribute. Attributes of this kind must implement an
34 /// interface for handling the translation of GPU offloading operations such as
35 /// `gpu.binary` and `gpu.launch_func`.
36 template <typename ConcreteType>
37 class OffloadingTranslationAttrTrait
38     : public AttributeTrait::TraitBase<ConcreteType,
39                                        OffloadingTranslationAttrTrait> {
40   // TODO: Verify that the attribute promises or implements the interface.
41 };
42 
43 /// This class serves as an opaque interface for passing options to the
44 /// `TargetAttrInterface` methods. Users of this class must implement the
45 /// `classof` method and use the `MLIR_*_EXPLICIT_TYPE_ID` macros to ensure
46 /// type safety. Targets are free to ignore these options.
47 class TargetOptions {
48 public:
49   /// Constructor initializing the toolkit path, the list of files to link to,
50   /// extra command line options, the ELF section, the compilation target, a
51   /// callback for obtaining the parent symbol table, and the LLVM IR / ISA
52   /// callbacks. The default compilation target is `Fatbin`.
53   TargetOptions(
54       StringRef toolkitPath = {}, ArrayRef<Attribute> librariesToLink = {},
55       StringRef cmdOptions = {}, StringRef elfSection = {},
56       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
57       function_ref<SymbolTable *()> getSymbolTableCallback = {},
58       function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
59       function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
60       function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
61       function_ref<void(StringRef)> isaCallback = {});
62 
63   /// Returns the typeID of these options.
64   TypeID getTypeID() const;
65 
66   /// Returns the path to the target toolkit.
67   StringRef getToolkitPath() const;
68 
69   /// Returns the LLVM libraries to link to.
70   ArrayRef<Attribute> getLibrariesToLink() const;
71 
72   /// Returns the command line options.
73   StringRef getCmdOptions() const;
74 
75   /// Returns the ELF section.
76   StringRef getELFSection() const;
77 
78   /// Returns the tokenized command line options and an allocator (presumably owning the tokens).
79   std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
80   tokenizeCmdOptions() const;
81 
82   /// Returns the compilation target.
83   CompilationTarget getCompilationTarget() const;
84 
85   /// Returns the result of the `getSymbolTableCallback` callback, or nullptr
86   /// if no callback was provided.
87   /// Note: The callback itself can return nullptr. It is up to the target how
88   /// to react to getting a nullptr, e.g., emitting an error or constructing the
89   /// table.
90   SymbolTable *getSymbolTable() const;
91 
92   /// Returns the callback invoked with the initial LLVM IR for the device
93   /// module.
94   function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
95 
96   /// Returns the callback invoked with the LLVM IR for the device module
97   /// after linking the device libraries.
98   function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
99 
100   /// Returns the callback invoked with the LLVM IR for the device module after
101   /// LLVM optimizations but before codegen.
102   function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
103 
104   /// Returns the callback invoked with the target ISA for the device,
105   /// for example PTX assembly.
106   function_ref<void(StringRef)> getISACallback() const;
107 
108   /// Returns the default compilation target: `CompilationTarget::Fatbin`.
109   static CompilationTarget getDefaultCompilationTarget();
110 
111 protected:
112   /// Derived classes must use this constructor to initialize `typeID` to the
113   /// appropriate value, i.e. `TargetOptions(TypeID::get<DerivedClass>())`.
114   TargetOptions(
115       TypeID typeID, StringRef toolkitPath = {},
116       ArrayRef<Attribute> librariesToLink = {}, StringRef cmdOptions = {},
117       StringRef elfSection = {},
118       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
119       function_ref<SymbolTable *()> getSymbolTableCallback = {},
120       function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
121       function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
122       function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
123       function_ref<void(StringRef)> isaCallback = {});
124 
125   /// Path to the target toolkit.
126   std::string toolkitPath;
127 
128   /// List of files to link with the LLVM module.
129   SmallVector<Attribute> librariesToLink;
130 
131   /// An optional set of command line options to be used by the compilation
132   /// process.
133   std::string cmdOptions;
134 
135   /// ELF section where the binary needs to be located.
136   std::string elfSection;
137 
138   /// Compilation process target format.
139   CompilationTarget compilationTarget;
140 
141   /// Callback for obtaining the parent symbol table of all the GPU modules
142   /// being serialized. NOTE(review): `function_ref` does not own the callable.
143   function_ref<SymbolTable *()> getSymbolTableCallback;
144 
145   /// Callback invoked with the initial LLVM IR for the device module.
146   function_ref<void(llvm::Module &)> initialLlvmIRCallback;
147 
148   /// Callback invoked with the LLVM IR for the device module after
149   /// linking the device libraries.
150   function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
151 
152   /// Callback invoked with the LLVM IR for the device module after
153   /// LLVM optimizations but before codegen.
154   function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
155 
156   /// Callback invoked with the target ISA for the device,
157   /// for example PTX assembly.
158   function_ref<void(StringRef)> isaCallback;
159 
160 private:
161   TypeID typeID; // Derived classes supply this through the protected constructor.
162 };
163 } // namespace gpu
164 } // namespace mlir
165 
166 MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions) // Declares the explicit TypeID for `TargetOptions`.
167 
168 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.h.inc"
169 
170 #endif // MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
171