xref: /llvm-project/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (revision 40afff7bd95090a75bc68a0d26b8017cc0ae65c1)
1 //===- ModuleTranslation.h - MLIR to LLVM conversion ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
15 #define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
16 
17 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/IR/Value.h"
21 #include "mlir/Target/LLVMIR/Export.h"
22 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
23 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
24 
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
27 
28 namespace llvm {
29 class BasicBlock;
30 class IRBuilderBase;
31 class Function;
32 class Value;
33 } // namespace llvm
34 
35 namespace mlir {
36 class Attribute;
37 class Block;
38 class Location;
39 
40 namespace LLVM {
41 
42 namespace detail {
43 class DebugTranslation;
44 class LoopAnnotationTranslation;
45 } // namespace detail
46 
47 class AliasScopeAttr;
48 class AliasScopeDomainAttr;
49 class DINodeAttr;
50 class LLVMFuncOp;
51 class ComdatSelectorOp;
52 
53 /// Implementation class for module translation. Holds a reference to the module
54 /// being translated, and the mappings between the original and the translated
55 /// functions, basic blocks and values. It is practically easier to hold these
56 /// mappings in one class since the conversion of control flow operations
57 /// needs to look up block and function mappings.
58 class ModuleTranslation {
59   friend std::unique_ptr<llvm::Module>
60   mlir::translateModuleToLLVMIR(Operation *, llvm::LLVMContext &, StringRef,
61                                 bool);
62 
63 public:
64   /// Stores the mapping between a function name and its LLVM IR representation.
65   void mapFunction(StringRef name, llvm::Function *func) {
66     auto result = functionMapping.try_emplace(name, func);
67     (void)result;
68     assert(result.second &&
69            "attempting to map a function that is already mapped");
70   }
71 
72   /// Finds an LLVM IR function by its name.
73   llvm::Function *lookupFunction(StringRef name) const {
74     return functionMapping.lookup(name);
75   }
76 
77   /// Stores the mapping between an MLIR value and its LLVM IR counterpart.
78   void mapValue(Value mlir, llvm::Value *llvm) { mapValue(mlir) = llvm; }
79 
80   /// Provides write-once access to store the LLVM IR value corresponding to the
81   /// given MLIR value.
82   llvm::Value *&mapValue(Value value) {
83     llvm::Value *&llvm = valueMapping[value];
84     assert(llvm == nullptr &&
85            "attempting to map a value that is already mapped");
86     return llvm;
87   }
88 
89   /// Finds an LLVM IR value corresponding to the given MLIR value.
90   llvm::Value *lookupValue(Value value) const {
91     return valueMapping.lookup(value);
92   }
93 
94   /// Looks up remapped a list of remapped values.
95   SmallVector<llvm::Value *> lookupValues(ValueRange values);
96 
97   /// Stores the mapping between an MLIR block and LLVM IR basic block.
98   void mapBlock(Block *mlir, llvm::BasicBlock *llvm) {
99     auto result = blockMapping.try_emplace(mlir, llvm);
100     (void)result;
101     assert(result.second && "attempting to map a block that is already mapped");
102   }
103 
104   /// Finds an LLVM IR basic block that corresponds to the given MLIR block.
105   llvm::BasicBlock *lookupBlock(Block *block) const {
106     return blockMapping.lookup(block);
107   }
108 
109   /// Stores the mapping between an MLIR operation with successors and a
110   /// corresponding LLVM IR instruction.
111   void mapBranch(Operation *mlir, llvm::Instruction *llvm) {
112     auto result = branchMapping.try_emplace(mlir, llvm);
113     (void)result;
114     assert(result.second &&
115            "attempting to map a branch that is already mapped");
116   }
117 
118   /// Finds an LLVM IR instruction that corresponds to the given MLIR operation
119   /// with successors.
120   llvm::Instruction *lookupBranch(Operation *op) const {
121     return branchMapping.lookup(op);
122   }
123 
124   /// Stores a mapping between an MLIR call operation and a corresponding LLVM
125   /// call instruction.
126   void mapCall(Operation *mlir, llvm::CallInst *llvm) {
127     auto result = callMapping.try_emplace(mlir, llvm);
128     (void)result;
129     assert(result.second && "attempting to map a call that is already mapped");
130   }
131 
132   /// Finds an LLVM call instruction that corresponds to the given MLIR call
133   /// operation.
134   llvm::CallInst *lookupCall(Operation *op) const {
135     return callMapping.lookup(op);
136   }
137 
138   /// Removes the mapping for blocks contained in the region and values defined
139   /// in these blocks.
140   void forgetMapping(Region &region);
141 
142   /// Returns the LLVM metadata corresponding to a mlir LLVM dialect alias scope
143   /// attribute. Creates the metadata node if it has not been converted before.
144   llvm::MDNode *getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr);
145 
146   /// Returns the LLVM metadata corresponding to an array of mlir LLVM dialect
147   /// alias scope attributes. Creates the metadata nodes if they have not been
148   /// converted before.
149   llvm::MDNode *
150   getOrCreateAliasScopes(ArrayRef<AliasScopeAttr> aliasScopeAttrs);
151 
152   // Sets LLVM metadata for memory operations that are in a parallel loop.
153   void setAccessGroupsMetadata(AccessGroupOpInterface op,
154                                llvm::Instruction *inst);
155 
156   // Sets LLVM metadata for memory operations that have alias scope information.
157   void setAliasScopeMetadata(AliasAnalysisOpInterface op,
158                              llvm::Instruction *inst);
159 
160   /// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
161   void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
162 
163   /// Sets LLVM profiling metadata for operations that have branch weights.
164   void setBranchWeightsMetadata(BranchWeightOpInterface op);
165 
166   /// Sets LLVM loop metadata for branch operations that have a loop annotation
167   /// attribute.
168   void setLoopMetadata(Operation *op, llvm::Instruction *inst);
169 
170   /// Sets the disjoint flag attribute for the exported instruction `value`
171   /// given the original operation `op`. Asserts if the operation does
172   /// not implement the disjoint flag interface, and asserts if the value
173   /// is an instruction that implements the disjoint flag.
174   void setDisjointFlag(Operation *op, llvm::Value *value);
175 
176   /// Converts the type from MLIR LLVM dialect to LLVM.
177   llvm::Type *convertType(Type type);
178 
179   /// Returns the MLIR context of the module being translated.
180   MLIRContext &getContext() { return *mlirModule->getContext(); }
181 
182   /// Returns the LLVM context in which the IR is being constructed.
183   llvm::LLVMContext &getLLVMContext() const { return llvmModule->getContext(); }
184 
185   /// Finds an LLVM IR global value that corresponds to the given MLIR operation
186   /// defining a global value.
187   llvm::GlobalValue *lookupGlobal(Operation *op) {
188     return globalsMapping.lookup(op);
189   }
190 
191   /// Returns the OpenMP IR builder associated with the LLVM IR module being
192   /// constructed.
193   llvm::OpenMPIRBuilder *getOpenMPBuilder();
194 
195   /// Returns the LLVM module in which the IR is being constructed.
196   llvm::Module *getLLVMModule() { return llvmModule.get(); }
197 
198   /// Translates the given location.
199   llvm::DILocation *translateLoc(Location loc, llvm::DILocalScope *scope);
200 
201   /// Translates the given LLVM DWARF expression metadata.
202   llvm::DIExpression *translateExpression(LLVM::DIExpressionAttr attr);
203 
204   /// Translates the given LLVM global variable expression metadata.
205   llvm::DIGlobalVariableExpression *
206   translateGlobalVariableExpression(LLVM::DIGlobalVariableExpressionAttr attr);
207 
208   /// Translates the given LLVM debug info metadata.
209   llvm::Metadata *translateDebugInfo(LLVM::DINodeAttr attr);
210 
211   /// Translates the given LLVM rounding mode metadata.
212   llvm::RoundingMode translateRoundingMode(LLVM::RoundingMode rounding);
213 
214   /// Translates the given LLVM FP exception behavior metadata.
215   llvm::fp::ExceptionBehavior
216   translateFPExceptionBehavior(LLVM::FPExceptionBehavior exceptionBehavior);
217 
218   /// Translates the contents of the given block to LLVM IR using this
219   /// translator. The LLVM IR basic block corresponding to the given block is
220   /// expected to exist in the mapping of this translator. Uses `builder` to
221   /// translate the IR, leaving it at the end of the block. If `ignoreArguments`
222   /// is set, does not produce PHI nodes for the block arguments. Otherwise, the
223   /// PHI nodes are constructed for block arguments but are _not_ connected to
224   /// the predecessors that may not exist yet.
225   LogicalResult convertBlock(Block &bb, bool ignoreArguments,
226                              llvm::IRBuilderBase &builder) {
227     return convertBlockImpl(bb, ignoreArguments, builder,
228                             /*recordInsertions=*/false);
229   }
230 
231   /// Gets the named metadata in the LLVM IR module being constructed, creating
232   /// it if it does not exist.
233   llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name);
234 
235   /// Common CRTP base class for ModuleTranslation stack frames.
236   class StackFrame {
237   public:
238     virtual ~StackFrame() = default;
239     TypeID getTypeID() const { return typeID; }
240 
241   protected:
242     explicit StackFrame(TypeID typeID) : typeID(typeID) {}
243 
244   private:
245     const TypeID typeID;
246     virtual void anchor();
247   };
248 
249   /// Concrete CRTP base class for ModuleTranslation stack frames. When
250   /// translating operations with regions, users of ModuleTranslation can store
251   /// state on ModuleTranslation stack before entering the region and inspect
252   /// it when converting operations nested within that region. Users are
253   /// expected to derive this class and put any relevant information into fields
254   /// of the derived class. The usual isa/dyn_cast functionality is available
255   /// for instances of derived classes.
256   template <typename Derived>
257   class StackFrameBase : public StackFrame {
258   public:
259     explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {}
260   };
261 
262   /// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must
263   /// be derived from `StackFrameBase<T>` and constructible from the provided
264   /// arguments. Doing this before entering the region of the op being
265   /// translated makes the frame available when translating ops within that
266   /// region.
267   template <typename T, typename... Args>
268   void stackPush(Args &&...args) {
269     static_assert(
270         std::is_base_of<StackFrame, T>::value,
271         "can only push instances of StackFrame on ModuleTranslation stack");
272     stack.push_back(std::make_unique<T>(std::forward<Args>(args)...));
273   }
274 
275   /// Pops the last element from the ModuleTranslation stack.
276   void stackPop() { stack.pop_back(); }
277 
278   /// Calls `callback` for every ModuleTranslation stack frame of type `T`
279   /// starting from the top of the stack.
280   template <typename T>
281   WalkResult
282   stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const {
283     static_assert(std::is_base_of<StackFrame, T>::value,
284                   "expected T derived from StackFrame");
285     if (!callback)
286       return WalkResult::skip();
287     for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
288       if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
289         WalkResult result = callback(*ptr);
290         if (result.wasInterrupted())
291           return result;
292       }
293     }
294     return WalkResult::advance();
295   }
296 
297   /// RAII object calling stackPush/stackPop on construction/destruction.
298   template <typename T>
299   struct SaveStack {
300     template <typename... Args>
301     explicit SaveStack(ModuleTranslation &m, Args &&...args)
302         : moduleTranslation(m) {
303       moduleTranslation.stackPush<T>(std::forward<Args>(args)...);
304     }
305     ~SaveStack() { moduleTranslation.stackPop(); }
306 
307   private:
308     ModuleTranslation &moduleTranslation;
309   };
310 
311   SymbolTableCollection &symbolTable() { return symbolTableCollection; }
312 
313 private:
314   ModuleTranslation(Operation *module,
315                     std::unique_ptr<llvm::Module> llvmModule);
316   ~ModuleTranslation();
317 
318   /// Converts individual components.
319   LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder,
320                                  bool recordInsertions = false);
321   LogicalResult convertFunctionSignatures();
322   LogicalResult convertFunctions();
323   LogicalResult convertComdats();
324   LogicalResult convertGlobals();
325   LogicalResult convertOneFunction(LLVMFuncOp func);
326   LogicalResult convertBlockImpl(Block &bb, bool ignoreArguments,
327                                  llvm::IRBuilderBase &builder,
328                                  bool recordInsertions);
329 
330   /// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
331   /// TBAATagAttr.
332   llvm::MDNode *getTBAANode(TBAATagAttr tbaaAttr) const;
333 
334   /// Process tbaa LLVM Metadata operations and create LLVM
335   /// metadata nodes for them.
336   LogicalResult createTBAAMetadata();
337 
338   /// Process the ident LLVM Metadata, if it exists.
339   LogicalResult createIdentMetadata();
340 
341   /// Process the llvm.commandline LLVM Metadata, if it exists.
342   LogicalResult createCommandlineMetadata();
343 
344   /// Translates dialect attributes attached to the given operation.
345   LogicalResult
346   convertDialectAttributes(Operation *op,
347                            ArrayRef<llvm::Instruction *> instructions);
348 
349   /// Translates parameter attributes and adds them to the returned AttrBuilder.
350   /// Returns failure if any of the translations failed.
351   FailureOr<llvm::AttrBuilder>
352   convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
353 
354   /// Original and translated module.
355   Operation *mlirModule;
356   std::unique_ptr<llvm::Module> llvmModule;
357   /// A converter for translating debug information.
358   std::unique_ptr<detail::DebugTranslation> debugTranslation;
359 
360   /// A converter for translating loop annotations.
361   std::unique_ptr<detail::LoopAnnotationTranslation> loopAnnotationTranslation;
362 
363   /// Builder for LLVM IR generation of OpenMP constructs.
364   std::unique_ptr<llvm::OpenMPIRBuilder> ompBuilder;
365 
366   /// Mappings between llvm.mlir.global definitions and corresponding globals.
367   DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
368 
369   /// A stateful object used to translate types.
370   TypeToLLVMIRTranslator typeTranslator;
371 
372   /// A dialect interface collection used for dispatching the translation to
373   /// specific dialects.
374   LLVMTranslationInterface iface;
375 
376   /// Mappings between original and translated values, used for lookups.
377   llvm::StringMap<llvm::Function *> functionMapping;
378   DenseMap<Value, llvm::Value *> valueMapping;
379   DenseMap<Block *, llvm::BasicBlock *> blockMapping;
380 
381   /// A mapping between MLIR LLVM dialect terminators and LLVM IR terminators
382   /// they are converted to. This allows for connecting PHI nodes to the source
383   /// values after all operations are converted.
384   DenseMap<Operation *, llvm::Instruction *> branchMapping;
385 
386   /// A mapping between MLIR LLVM dialect call operations and LLVM IR call
387   /// instructions. This allows for adding branch weights after the operations
388   /// have been converted.
389   DenseMap<Operation *, llvm::CallInst *> callMapping;
390 
391   /// Mapping from an alias scope attribute to its LLVM metadata.
392   /// This map is populated lazily.
393   DenseMap<AliasScopeAttr, llvm::MDNode *> aliasScopeMetadataMapping;
394 
395   /// Mapping from an alias scope domain attribute to its LLVM metadata.
396   /// This map is populated lazily.
397   DenseMap<AliasScopeDomainAttr, llvm::MDNode *> aliasDomainMetadataMapping;
398 
399   /// Mapping from a tbaa attribute to its LLVM metadata.
400   /// This map is populated on module entry.
401   DenseMap<Attribute, llvm::MDNode *> tbaaMetadataMapping;
402 
403   /// Mapping from a comdat selector operation to its LLVM comdat struct.
404   /// This map is populated on module entry.
405   DenseMap<ComdatSelectorOp, llvm::Comdat *> comdatMapping;
406 
407   /// Stack of user-specified state elements, useful when translating operations
408   /// with regions.
409   SmallVector<std::unique_ptr<StackFrame>> stack;
410 
411   /// A cache for the symbol tables constructed during symbols lookup.
412   SymbolTableCollection symbolTableCollection;
413 };
414 
415 namespace detail {
416 /// For all blocks in the region that were converted to LLVM IR using the given
417 /// ModuleTranslation, connect the PHI nodes of the corresponding LLVM IR blocks
418 /// to the results of preceding blocks.
419 void connectPHINodes(Region &region, const ModuleTranslation &state);
420 
421 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
422 /// This currently supports integer, floating point, splat and dense element
423 /// attributes and combinations thereof. Also, an array attribute with two
424 /// elements is supported to represent a complex constant.  In case of error,
425 /// report it to `loc` and return nullptr.
426 llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
427                                 Location loc,
428                                 const ModuleTranslation &moduleTranslation);
429 
430 /// Creates a call to an LLVM IR intrinsic function with the given arguments.
431 llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &builder,
432                                     llvm::Intrinsic::ID intrinsic,
433                                     ArrayRef<llvm::Value *> args = {},
434                                     ArrayRef<llvm::Type *> tys = {});
435 
436 /// Creates a call to a LLVM IR intrinsic defined by LLVM_IntrOpBase. This
437 /// resolves the overloads, and maps mixed MLIR value and attribute arguments to
438 /// LLVM values.
439 llvm::CallInst *createIntrinsicCall(
440     llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
441     Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
442     ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
443     ArrayRef<unsigned> immArgPositions,
444     ArrayRef<StringLiteral> immArgAttrNames);
445 
446 } // namespace detail
447 
448 } // namespace LLVM
449 } // namespace mlir
450 
451 namespace llvm {
452 template <typename T>
453 struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> {
454   static inline bool
455   doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) {
456     return frame.getTypeID() == ::mlir::TypeID::get<T>();
457   }
458 };
459 } // namespace llvm
460 
461 #endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
462