xref: /llvm-project/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp (revision 1dfb104eac73863b06751bea225ffa6ef589577f)
1 //===- LLVMIRToLLVMTranslation.cpp - Translate LLVM IR to LLVM dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/ModRef.h"

#include <map>
#include <mutex>
28 
29 using namespace mlir;
30 using namespace mlir::LLVM;
31 using namespace mlir::LLVM::detail;
32 
33 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
34 
// Names of the metadata kinds that are looked up by name through
// LLVMContext::getMDKindID (rather than via a fixed llvm::LLVMContext::MD_*
// enumerator). Metadata of these kinds is imported as attributes on
// LLVM::LLVMFuncOp.
static constexpr StringLiteral vecTypeHintMDName = "vec_type_hint";
static constexpr StringLiteral workGroupSizeHintMDName = "work_group_size_hint";
static constexpr StringLiteral reqdWorkGroupSizeMDName = "reqd_work_group_size";
static constexpr StringLiteral intelReqdSubGroupSizeMDName =
    "intel_reqd_sub_group_size";
40 
41 /// Returns true if the LLVM IR intrinsic is convertible to an MLIR LLVM dialect
42 /// intrinsic. Returns false otherwise.
43 static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
44   static const DenseSet<unsigned> convertibleIntrinsics = {
45 #include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc"
46   };
47   return convertibleIntrinsics.contains(id);
48 }
49 
50 /// Returns the list of LLVM IR intrinsic identifiers that are convertible to
51 /// MLIR LLVM dialect intrinsics.
52 static ArrayRef<unsigned> getSupportedIntrinsicsImpl() {
53   static const SmallVector<unsigned> convertibleIntrinsics = {
54 #include "mlir/Dialect/LLVMIR/LLVMConvertibleLLVMIRIntrinsics.inc"
55   };
56   return convertibleIntrinsics;
57 }
58 
/// Converts the LLVM intrinsic call `inst` to an MLIR LLVM dialect operation
/// if a conversion exists. Returns failure otherwise, including when the
/// intrinsic is not in the convertible set.
///
/// NOTE(review): the conversion code itself comes from the included
/// LLVMIntrinsicFromLLVMIRConversions.inc, which presumably refers to the
/// local names declared here (`odsBuilder`, `inst`, `moduleImport`,
/// `intrinsicID`, `llvmOperands`, `llvmOpBundles`). Check the generated file
/// before renaming any of them or changing their scope.
static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
                                          llvm::CallInst *inst,
                                          LLVM::ModuleImport &moduleImport) {
  llvm::Intrinsic::ID intrinsicID = inst->getIntrinsicID();

  // Check if the intrinsic is convertible to an MLIR dialect counterpart and
  // copy the arguments to an LLVM operands array reference for conversion.
  if (isConvertibleIntrinsic(intrinsicID)) {
    SmallVector<llvm::Value *> args(inst->args());
    ArrayRef<llvm::Value *> llvmOperands(args);

    // Collect the call's operand bundles; presumably consumed by the
    // generated conversion code below.
    SmallVector<llvm::OperandBundleUse> llvmOpBundles;
    llvmOpBundles.reserve(inst->getNumOperandBundles());
    for (unsigned i = 0; i < inst->getNumOperandBundles(); ++i)
      llvmOpBundles.push_back(inst->getOperandBundleAt(i));

#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"
  }

  // Reached when the intrinsic is not convertible or, presumably, when the
  // generated code did not return for this intrinsic.
  return failure();
}
82 
83 /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM
84 /// dialect attributes.
85 static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
86   static const SmallVector<unsigned> convertibleMetadata = {
87       llvm::LLVMContext::MD_prof,
88       llvm::LLVMContext::MD_tbaa,
89       llvm::LLVMContext::MD_access_group,
90       llvm::LLVMContext::MD_loop,
91       llvm::LLVMContext::MD_noalias,
92       llvm::LLVMContext::MD_alias_scope,
93       context.getMDKindID(vecTypeHintMDName),
94       context.getMDKindID(workGroupSizeHintMDName),
95       context.getMDKindID(reqdWorkGroupSizeMDName),
96       context.getMDKindID(intelReqdSubGroupSizeMDName)};
97   return convertibleMetadata;
98 }
99 
100 /// Converts the given profiling metadata `node` to an MLIR profiling attribute
101 /// and attaches it to the imported operation if the translation succeeds.
102 /// Returns failure otherwise.
103 static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
104                                       Operation *op,
105                                       LLVM::ModuleImport &moduleImport) {
106   // Return failure for empty metadata nodes since there is nothing to import.
107   if (!node->getNumOperands())
108     return failure();
109 
110   auto *name = dyn_cast<llvm::MDString>(node->getOperand(0));
111   if (!name)
112     return failure();
113 
114   // Handle function entry count metadata.
115   if (name->getString() == "function_entry_count") {
116 
117     // TODO support function entry count metadata with GUID fields.
118     if (node->getNumOperands() != 2)
119       return failure();
120 
121     llvm::ConstantInt *entryCount =
122         llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(1));
123     if (!entryCount)
124       return failure();
125     if (auto funcOp = dyn_cast<LLVMFuncOp>(op)) {
126       funcOp.setFunctionEntryCount(entryCount->getZExtValue());
127       return success();
128     }
129     return op->emitWarning()
130            << "expected function_entry_count to be attached to a function";
131   }
132 
133   if (name->getString() != "branch_weights")
134     return failure();
135 
136   // Handle branch weights metadata.
137   SmallVector<int32_t> branchWeights;
138   branchWeights.reserve(node->getNumOperands() - 1);
139   for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) {
140     llvm::ConstantInt *branchWeight =
141         llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(i));
142     if (!branchWeight)
143       return failure();
144     branchWeights.push_back(branchWeight->getZExtValue());
145   }
146 
147   if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
148     iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
149     return success();
150   }
151   return failure();
152 }
153 
154 /// Searches for the attribute that maps to the given TBAA metadata `node` and
155 /// attaches it to the imported operation if the lookup succeeds. Returns
156 /// failure otherwise.
157 static LogicalResult setTBAAAttr(const llvm::MDNode *node, Operation *op,
158                                  LLVM::ModuleImport &moduleImport) {
159   Attribute tbaaTagSym = moduleImport.lookupTBAAAttr(node);
160   if (!tbaaTagSym)
161     return failure();
162 
163   auto iface = dyn_cast<AliasAnalysisOpInterface>(op);
164   if (!iface)
165     return failure();
166 
167   iface.setTBAATags(ArrayAttr::get(iface.getContext(), tbaaTagSym));
168   return success();
169 }
170 
171 /// Looks up all the access group attributes that map to the access group nodes
172 /// starting from the access group metadata `node`, and attaches all of them to
173 /// the imported operation if the lookups succeed. Returns failure otherwise.
174 static LogicalResult setAccessGroupsAttr(const llvm::MDNode *node,
175                                          Operation *op,
176                                          LLVM::ModuleImport &moduleImport) {
177   FailureOr<SmallVector<AccessGroupAttr>> accessGroups =
178       moduleImport.lookupAccessGroupAttrs(node);
179   if (failed(accessGroups))
180     return failure();
181 
182   auto iface = dyn_cast<AccessGroupOpInterface>(op);
183   if (!iface)
184     return failure();
185 
186   iface.setAccessGroups(ArrayAttr::get(
187       iface.getContext(), llvm::to_vector_of<Attribute>(*accessGroups)));
188   return success();
189 }
190 
191 /// Converts the given loop metadata node to an MLIR loop annotation attribute
192 /// and attaches it to the imported operation if the translation succeeds.
193 /// Returns failure otherwise.
194 static LogicalResult setLoopAttr(const llvm::MDNode *node, Operation *op,
195                                  LLVM::ModuleImport &moduleImport) {
196   LoopAnnotationAttr attr =
197       moduleImport.translateLoopAnnotationAttr(node, op->getLoc());
198   if (!attr)
199     return failure();
200 
201   return TypeSwitch<Operation *, LogicalResult>(op)
202       .Case<LLVM::BrOp, LLVM::CondBrOp>([&](auto branchOp) {
203         branchOp.setLoopAnnotationAttr(attr);
204         return success();
205       })
206       .Default([](auto) { return failure(); });
207 }
208 
209 /// Looks up all the alias scope attributes that map to the alias scope nodes
210 /// starting from the alias scope metadata `node`, and attaches all of them to
211 /// the imported operation if the lookups succeed. Returns failure otherwise.
212 static LogicalResult setAliasScopesAttr(const llvm::MDNode *node, Operation *op,
213                                         LLVM::ModuleImport &moduleImport) {
214   FailureOr<SmallVector<AliasScopeAttr>> aliasScopes =
215       moduleImport.lookupAliasScopeAttrs(node);
216   if (failed(aliasScopes))
217     return failure();
218 
219   auto iface = dyn_cast<AliasAnalysisOpInterface>(op);
220   if (!iface)
221     return failure();
222 
223   iface.setAliasScopes(ArrayAttr::get(
224       iface.getContext(), llvm::to_vector_of<Attribute>(*aliasScopes)));
225   return success();
226 }
227 
228 /// Looks up all the alias scope attributes that map to the alias scope nodes
229 /// starting from the noalias metadata `node`, and attaches all of them to the
230 /// imported operation if the lookups succeed. Returns failure otherwise.
231 static LogicalResult setNoaliasScopesAttr(const llvm::MDNode *node,
232                                           Operation *op,
233                                           LLVM::ModuleImport &moduleImport) {
234   FailureOr<SmallVector<AliasScopeAttr>> noAliasScopes =
235       moduleImport.lookupAliasScopeAttrs(node);
236   if (failed(noAliasScopes))
237     return failure();
238 
239   auto iface = dyn_cast<AliasAnalysisOpInterface>(op);
240   if (!iface)
241     return failure();
242 
243   iface.setNoAliasScopes(ArrayAttr::get(
244       iface.getContext(), llvm::to_vector_of<Attribute>(*noAliasScopes)));
245   return success();
246 }
247 
248 /// Extracts an integer from the provided metadata `md` if possible. Returns
249 /// nullopt otherwise.
250 static std::optional<int32_t> parseIntegerMD(llvm::Metadata *md) {
251   auto *constant = dyn_cast_if_present<llvm::ConstantAsMetadata>(md);
252   if (!constant)
253     return {};
254 
255   auto *intConstant = dyn_cast<llvm::ConstantInt>(constant->getValue());
256   if (!intConstant)
257     return {};
258 
259   return intConstant->getValue().getSExtValue();
260 }
261 
262 /// Converts the provided metadata node `node` to an LLVM dialect
263 /// VecTypeHintAttr if possible.
264 static VecTypeHintAttr convertVecTypeHint(Builder builder, llvm::MDNode *node,
265                                           ModuleImport &moduleImport) {
266   if (!node || node->getNumOperands() != 2)
267     return {};
268 
269   auto *hintMD = dyn_cast<llvm::ValueAsMetadata>(node->getOperand(0).get());
270   if (!hintMD)
271     return {};
272   TypeAttr hint = TypeAttr::get(moduleImport.convertType(hintMD->getType()));
273 
274   std::optional<int32_t> optIsSigned =
275       parseIntegerMD(node->getOperand(1).get());
276   if (!optIsSigned)
277     return {};
278   bool isSigned = *optIsSigned != 0;
279 
280   return builder.getAttr<VecTypeHintAttr>(hint, isSigned);
281 }
282 
283 /// Converts the provided metadata node `node` to an MLIR DenseI32ArrayAttr if
284 /// possible.
285 static DenseI32ArrayAttr convertDenseI32Array(Builder builder,
286                                               llvm::MDNode *node) {
287   if (!node)
288     return {};
289   SmallVector<int32_t> vals;
290   for (const llvm::MDOperand &op : node->operands()) {
291     std::optional<int32_t> mdValue = parseIntegerMD(op.get());
292     if (!mdValue)
293       return {};
294     vals.push_back(*mdValue);
295   }
296   return builder.getDenseI32ArrayAttr(vals);
297 }
298 
299 /// Convert an `MDNode` to an MLIR `IntegerAttr` if possible.
300 static IntegerAttr convertIntegerMD(Builder builder, llvm::MDNode *node) {
301   if (!node || node->getNumOperands() != 1)
302     return {};
303   std::optional<int32_t> val = parseIntegerMD(node->getOperand(0));
304   if (!val)
305     return {};
306   return builder.getI32IntegerAttr(*val);
307 }
308 
309 static LogicalResult setVecTypeHintAttr(Builder &builder, llvm::MDNode *node,
310                                         Operation *op,
311                                         LLVM::ModuleImport &moduleImport) {
312   auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
313   if (!funcOp)
314     return failure();
315 
316   VecTypeHintAttr attr = convertVecTypeHint(builder, node, moduleImport);
317   if (!attr)
318     return failure();
319 
320   funcOp.setVecTypeHintAttr(attr);
321   return success();
322 }
323 
324 static LogicalResult
325 setWorkGroupSizeHintAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
326   auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
327   if (!funcOp)
328     return failure();
329 
330   DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
331   if (!attr)
332     return failure();
333 
334   funcOp.setWorkGroupSizeHintAttr(attr);
335   return success();
336 }
337 
338 static LogicalResult
339 setReqdWorkGroupSizeAttr(Builder &builder, llvm::MDNode *node, Operation *op) {
340   auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
341   if (!funcOp)
342     return failure();
343 
344   DenseI32ArrayAttr attr = convertDenseI32Array(builder, node);
345   if (!attr)
346     return failure();
347 
348   funcOp.setReqdWorkGroupSizeAttr(attr);
349   return success();
350 }
351 
352 /// Converts the given intel required subgroup size metadata node to an MLIR
353 /// attribute and attaches it to the imported operation if the translation
354 /// succeeds. Returns failure otherwise.
355 static LogicalResult setIntelReqdSubGroupSizeAttr(Builder &builder,
356                                                   llvm::MDNode *node,
357                                                   Operation *op) {
358   auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
359   if (!funcOp)
360     return failure();
361 
362   IntegerAttr attr = convertIntegerMD(builder, node);
363   if (!attr)
364     return failure();
365 
366   funcOp.setIntelReqdSubGroupSizeAttr(attr);
367   return success();
368 }
369 
370 namespace {
371 
372 /// Implementation of the dialect interface that converts operations belonging
373 /// to the LLVM dialect to LLVM IR.
374 class LLVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
375 public:
376   using LLVMImportDialectInterface::LLVMImportDialectInterface;
377 
378   /// Converts the LLVM intrinsic to an MLIR LLVM dialect operation if a
379   /// conversion exits. Returns failure otherwise.
380   LogicalResult convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst,
381                                  LLVM::ModuleImport &moduleImport) const final {
382     return convertIntrinsicImpl(builder, inst, moduleImport);
383   }
384 
385   /// Attaches the given LLVM metadata to the imported operation if a conversion
386   /// to an LLVM dialect attribute exists and succeeds. Returns failure
387   /// otherwise.
388   LogicalResult setMetadataAttrs(OpBuilder &builder, unsigned kind,
389                                  llvm::MDNode *node, Operation *op,
390                                  LLVM::ModuleImport &moduleImport) const final {
391     // Call metadata specific handlers.
392     if (kind == llvm::LLVMContext::MD_prof)
393       return setProfilingAttr(builder, node, op, moduleImport);
394     if (kind == llvm::LLVMContext::MD_tbaa)
395       return setTBAAAttr(node, op, moduleImport);
396     if (kind == llvm::LLVMContext::MD_access_group)
397       return setAccessGroupsAttr(node, op, moduleImport);
398     if (kind == llvm::LLVMContext::MD_loop)
399       return setLoopAttr(node, op, moduleImport);
400     if (kind == llvm::LLVMContext::MD_alias_scope)
401       return setAliasScopesAttr(node, op, moduleImport);
402     if (kind == llvm::LLVMContext::MD_noalias)
403       return setNoaliasScopesAttr(node, op, moduleImport);
404 
405     llvm::LLVMContext &context = node->getContext();
406     if (kind == context.getMDKindID(vecTypeHintMDName))
407       return setVecTypeHintAttr(builder, node, op, moduleImport);
408     if (kind == context.getMDKindID(workGroupSizeHintMDName))
409       return setWorkGroupSizeHintAttr(builder, node, op);
410     if (kind == context.getMDKindID(reqdWorkGroupSizeMDName))
411       return setReqdWorkGroupSizeAttr(builder, node, op);
412     if (kind == context.getMDKindID(intelReqdSubGroupSizeMDName))
413       return setIntelReqdSubGroupSizeAttr(builder, node, op);
414 
415     // A handler for a supported metadata kind is missing.
416     llvm_unreachable("unknown metadata type");
417   }
418 
419   /// Returns the list of LLVM IR intrinsic identifiers that are convertible to
420   /// MLIR LLVM dialect intrinsics.
421   ArrayRef<unsigned> getSupportedIntrinsics() const final {
422     return getSupportedIntrinsicsImpl();
423   }
424 
425   /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR
426   /// LLVM dialect attributes.
427   ArrayRef<unsigned>
428   getSupportedMetadata(llvm::LLVMContext &context) const final {
429     return getSupportedMetadataImpl(context);
430   }
431 };
432 } // namespace
433 
434 void mlir::registerLLVMDialectImport(DialectRegistry &registry) {
435   registry.insert<LLVM::LLVMDialect>();
436   registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
437     dialect->addInterfaces<LLVMDialectLLVMIRImportInterface>();
438   });
439 }
440 
441 void mlir::registerLLVMDialectImport(MLIRContext &context) {
442   DialectRegistry registry;
443   registerLLVMDialectImport(registry);
444   context.appendDialectRegistry(registry);
445 }
446