xref: /llvm-project/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp (revision 1dfb104eac73863b06751bea225ffa6ef589577f)
1 //===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR LLVM dialect and LLVM IR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InlineAsm.h"
21 #include "llvm/IR/MDBuilder.h"
22 #include "llvm/IR/MatrixBuilder.h"
23 #include "llvm/IR/Operator.h"
24 
25 using namespace mlir;
26 using namespace mlir::LLVM;
27 using mlir::LLVM::detail::getLLVMConstant;
28 
29 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
30 
31 static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
32   using llvmFMF = llvm::FastMathFlags;
33   using FuncT = void (llvmFMF::*)(bool);
34   const std::pair<FastmathFlags, FuncT> handlers[] = {
35       // clang-format off
36       {FastmathFlags::nnan,     &llvmFMF::setNoNaNs},
37       {FastmathFlags::ninf,     &llvmFMF::setNoInfs},
38       {FastmathFlags::nsz,      &llvmFMF::setNoSignedZeros},
39       {FastmathFlags::arcp,     &llvmFMF::setAllowReciprocal},
40       {FastmathFlags::contract, &llvmFMF::setAllowContract},
41       {FastmathFlags::afn,      &llvmFMF::setApproxFunc},
42       {FastmathFlags::reassoc,  &llvmFMF::setAllowReassoc},
43       // clang-format on
44   };
45   llvm::FastMathFlags ret;
46   ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
47   for (auto it : handlers)
48     if (bitEnumContainsAll(fmfMlir, it.first))
49       (ret.*(it.second))(true);
50   return ret;
51 }
52 
53 /// Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
54 static SmallVector<unsigned> extractPosition(ArrayRef<int64_t> indices) {
55   SmallVector<unsigned> position;
56   llvm::append_range(position, indices);
57   return position;
58 }
59 
60 /// Convert an LLVM type to a string for printing in diagnostics.
61 static std::string diagStr(const llvm::Type *type) {
62   std::string str;
63   llvm::raw_string_ostream os(str);
64   type->print(os);
65   return str;
66 }
67 
68 /// Get the declaration of an overloaded llvm intrinsic. First we get the
69 /// overloaded argument types and/or result type from the CallIntrinsicOp, and
70 /// then use those to get the correct declaration of the overloaded intrinsic.
71 static FailureOr<llvm::Function *>
72 getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id,
73                          llvm::Module *module,
74                          LLVM::ModuleTranslation &moduleTranslation) {
75   SmallVector<llvm::Type *, 8> allArgTys;
76   for (Type type : op->getOperandTypes())
77     allArgTys.push_back(moduleTranslation.convertType(type));
78 
79   llvm::Type *resTy;
80   if (op.getNumResults() == 0)
81     resTy = llvm::Type::getVoidTy(module->getContext());
82   else
83     resTy = moduleTranslation.convertType(op.getResult(0).getType());
84 
85   // ATM we do not support variadic intrinsics.
86   llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false);
87 
88   SmallVector<llvm::Intrinsic::IITDescriptor, 8> table;
89   getIntrinsicInfoTableEntries(id, table);
90   ArrayRef<llvm::Intrinsic::IITDescriptor> tableRef = table;
91 
92   SmallVector<llvm::Type *, 8> overloadedArgTys;
93   if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
94                                                overloadedArgTys) !=
95       llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
96     return mlir::emitError(op.getLoc(), "call intrinsic signature ")
97            << diagStr(ft) << " to overloaded intrinsic " << op.getIntrinAttr()
98            << " does not match any of the overloads";
99   }
100 
101   ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys;
102   return llvm::Intrinsic::getOrInsertDeclaration(module, id,
103                                                  overloadedArgTysRef);
104 }
105 
106 static llvm::OperandBundleDef
107 convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
108                      LLVM::ModuleTranslation &moduleTranslation) {
109   std::vector<llvm::Value *> operands;
110   operands.reserve(bundleOperands.size());
111   for (Value bundleArg : bundleOperands)
112     operands.push_back(moduleTranslation.lookupValue(bundleArg));
113   return llvm::OperandBundleDef(bundleTag.str(), std::move(operands));
114 }
115 
116 static SmallVector<llvm::OperandBundleDef>
117 convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags,
118                       LLVM::ModuleTranslation &moduleTranslation) {
119   SmallVector<llvm::OperandBundleDef> bundles;
120   bundles.reserve(bundleOperands.size());
121 
122   for (auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) {
123     StringRef tag = cast<StringAttr>(tagAttr).getValue();
124     bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation));
125   }
126   return bundles;
127 }
128 
129 static SmallVector<llvm::OperandBundleDef>
130 convertOperandBundles(OperandRangeRange bundleOperands,
131                       std::optional<ArrayAttr> bundleTags,
132                       LLVM::ModuleTranslation &moduleTranslation) {
133   if (!bundleTags)
134     return {};
135   return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
136 }
137 
138 /// Builder for LLVM_CallIntrinsicOp
139 static LogicalResult
140 convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
141                            LLVM::ModuleTranslation &moduleTranslation) {
142   llvm::Module *module = builder.GetInsertBlock()->getModule();
143   llvm::Intrinsic::ID id =
144       llvm::Intrinsic::lookupIntrinsicID(op.getIntrinAttr());
145   if (!id)
146     return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ")
147            << op.getIntrinAttr();
148 
149   llvm::Function *fn = nullptr;
150   if (llvm::Intrinsic::isOverloaded(id)) {
151     auto fnOrFailure =
152         getOverloadedDeclaration(op, id, module, moduleTranslation);
153     if (failed(fnOrFailure))
154       return failure();
155     fn = *fnOrFailure;
156   } else {
157     fn = llvm::Intrinsic::getOrInsertDeclaration(module, id, {});
158   }
159 
160   // Check the result type of the call.
161   const llvm::Type *intrinType =
162       op.getNumResults() == 0
163           ? llvm::Type::getVoidTy(module->getContext())
164           : moduleTranslation.convertType(op.getResultTypes().front());
165   if (intrinType != fn->getReturnType()) {
166     return mlir::emitError(op.getLoc(), "intrinsic call returns ")
167            << diagStr(intrinType) << " but " << op.getIntrinAttr()
168            << " actually returns " << diagStr(fn->getReturnType());
169   }
170 
171   // Check the argument types of the call. If the function is variadic, check
172   // the subrange of required arguments.
173   if (!fn->getFunctionType()->isVarArg() &&
174       op.getArgs().size() != fn->arg_size()) {
175     return mlir::emitError(op.getLoc(), "intrinsic call has ")
176            << op.getArgs().size() << " operands but " << op.getIntrinAttr()
177            << " expects " << fn->arg_size();
178   }
179   if (fn->getFunctionType()->isVarArg() &&
180       op.getArgs().size() < fn->arg_size()) {
181     return mlir::emitError(op.getLoc(), "intrinsic call has ")
182            << op.getArgs().size() << " operands but variadic "
183            << op.getIntrinAttr() << " expects at least " << fn->arg_size();
184   }
185   // Check the arguments up to the number the function requires.
186   for (unsigned i = 0, e = fn->arg_size(); i != e; ++i) {
187     const llvm::Type *expected = fn->getArg(i)->getType();
188     const llvm::Type *actual =
189         moduleTranslation.convertType(op.getOperandTypes()[i]);
190     if (actual != expected) {
191       return mlir::emitError(op.getLoc(), "intrinsic call operand #")
192              << i << " has type " << diagStr(actual) << " but "
193              << op.getIntrinAttr() << " expects " << diagStr(expected);
194     }
195   }
196 
197   FastmathFlagsInterface itf = op;
198   builder.setFastMathFlags(getFastmathFlags(itf));
199 
200   auto *inst = builder.CreateCall(
201       fn, moduleTranslation.lookupValues(op.getArgs()),
202       convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
203                             moduleTranslation));
204   if (op.getNumResults() == 1)
205     moduleTranslation.mapValue(op->getResults().front()) = inst;
206   return success();
207 }
208 
209 static void convertLinkerOptionsOp(ArrayAttr options,
210                                    llvm::IRBuilderBase &builder,
211                                    LLVM::ModuleTranslation &moduleTranslation) {
212   llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
213   llvm::LLVMContext &context = llvmModule->getContext();
214   llvm::NamedMDNode *linkerMDNode =
215       llvmModule->getOrInsertNamedMetadata("llvm.linker.options");
216   SmallVector<llvm::Metadata *> MDNodes;
217   MDNodes.reserve(options.size());
218   for (auto s : options.getAsRange<StringAttr>()) {
219     auto *MDNode = llvm::MDString::get(context, s.getValue());
220     MDNodes.push_back(MDNode);
221   }
222 
223   auto *listMDNode = llvm::MDTuple::get(context, MDNodes);
224   linkerMDNode->addOperand(listMDNode);
225 }
226 
227 static LogicalResult
228 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
229                      LLVM::ModuleTranslation &moduleTranslation) {
230 
231   llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
232   if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
233     builder.setFastMathFlags(getFastmathFlags(fmf));
234 
235 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
236 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc"
237 
238   // Emit function calls.  If the "callee" attribute is present, this is a
239   // direct function call and we also need to look up the remapped function
240   // itself.  Otherwise, this is an indirect call and the callee is the first
241   // operand, look it up as a normal value.
242   if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
243     auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands());
244     SmallVector<llvm::OperandBundleDef> opBundles =
245         convertOperandBundles(callOp.getOpBundleOperands(),
246                               callOp.getOpBundleTags(), moduleTranslation);
247     ArrayRef<llvm::Value *> operandsRef(operands);
248     llvm::CallInst *call;
249     if (auto attr = callOp.getCalleeAttr()) {
250       call =
251           builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
252                              operandsRef, opBundles);
253     } else {
254       llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
255           moduleTranslation.convertType(callOp.getCalleeFunctionType()));
256       call = builder.CreateCall(calleeType, operandsRef.front(),
257                                 operandsRef.drop_front(), opBundles);
258     }
259     call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
260     call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
261     if (callOp.getConvergentAttr())
262       call->addFnAttr(llvm::Attribute::Convergent);
263     if (callOp.getNoUnwindAttr())
264       call->addFnAttr(llvm::Attribute::NoUnwind);
265     if (callOp.getWillReturnAttr())
266       call->addFnAttr(llvm::Attribute::WillReturn);
267 
268     if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
269       llvm::MemoryEffects memEffects =
270           llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
271                               convertModRefInfoToLLVM(memAttr.getArgMem())) |
272           llvm::MemoryEffects(
273               llvm::MemoryEffects::Location::InaccessibleMem,
274               convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) |
275           llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
276                               convertModRefInfoToLLVM(memAttr.getOther()));
277       call->setMemoryEffects(memEffects);
278     }
279 
280     moduleTranslation.setAccessGroupsMetadata(callOp, call);
281     moduleTranslation.setAliasScopeMetadata(callOp, call);
282     moduleTranslation.setTBAAMetadata(callOp, call);
283     // If the called function has a result, remap the corresponding value.  Note
284     // that LLVM IR dialect CallOp has either 0 or 1 result.
285     if (opInst.getNumResults() != 0)
286       moduleTranslation.mapValue(opInst.getResult(0), call);
287     // Check that LLVM call returns void for 0-result functions.
288     else if (!call->getType()->isVoidTy())
289       return failure();
290     moduleTranslation.mapCall(callOp, call);
291     return success();
292   }
293 
294   if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
295     // TODO: refactor function type creation which usually occurs in std-LLVM
296     // conversion.
297     SmallVector<Type, 8> operandTypes;
298     llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes());
299 
300     Type resultType;
301     if (inlineAsmOp.getNumResults() == 0) {
302       resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext());
303     } else {
304       assert(inlineAsmOp.getNumResults() == 1);
305       resultType = inlineAsmOp.getResultTypes()[0];
306     }
307     auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
308     llvm::InlineAsm *inlineAsmInst =
309         inlineAsmOp.getAsmDialect()
310             ? llvm::InlineAsm::get(
311                   static_cast<llvm::FunctionType *>(
312                       moduleTranslation.convertType(ft)),
313                   inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(),
314                   inlineAsmOp.getHasSideEffects(),
315                   inlineAsmOp.getIsAlignStack(),
316                   convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect()))
317             : llvm::InlineAsm::get(static_cast<llvm::FunctionType *>(
318                                        moduleTranslation.convertType(ft)),
319                                    inlineAsmOp.getAsmString(),
320                                    inlineAsmOp.getConstraints(),
321                                    inlineAsmOp.getHasSideEffects(),
322                                    inlineAsmOp.getIsAlignStack());
323     llvm::CallInst *inst = builder.CreateCall(
324         inlineAsmInst,
325         moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
326     if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
327       llvm::AttributeList attrList;
328       for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) {
329         Attribute attr = it.value();
330         if (!attr)
331           continue;
332         DictionaryAttr dAttr = cast<DictionaryAttr>(attr);
333         TypeAttr tAttr =
334             cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName()));
335         llvm::AttrBuilder b(moduleTranslation.getLLVMContext());
336         llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue());
337         b.addTypeAttr(llvm::Attribute::ElementType, ty);
338         // shift to account for the returned value (this is always 1 aggregate
339         // value in LLVM).
340         int shift = (opInst.getNumResults() > 0) ? 1 : 0;
341         attrList = attrList.addAttributesAtIndex(
342             moduleTranslation.getLLVMContext(), it.index() + shift, b);
343       }
344       inst->setAttributes(attrList);
345     }
346 
347     if (opInst.getNumResults() != 0)
348       moduleTranslation.mapValue(opInst.getResult(0), inst);
349     return success();
350   }
351 
352   if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
353     auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
354     SmallVector<llvm::OperandBundleDef> opBundles =
355         convertOperandBundles(invOp.getOpBundleOperands(),
356                               invOp.getOpBundleTags(), moduleTranslation);
357     ArrayRef<llvm::Value *> operandsRef(operands);
358     llvm::InvokeInst *result;
359     if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
360       result = builder.CreateInvoke(
361           moduleTranslation.lookupFunction(attr.getValue()),
362           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
363           moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef,
364           opBundles);
365     } else {
366       llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
367           moduleTranslation.convertType(invOp.getCalleeFunctionType()));
368       result = builder.CreateInvoke(
369           calleeType, operandsRef.front(),
370           moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
371           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
372           operandsRef.drop_front(), opBundles);
373     }
374     result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
375     moduleTranslation.mapBranch(invOp, result);
376     // InvokeOp can only have 0 or 1 result
377     if (invOp->getNumResults() != 0) {
378       moduleTranslation.mapValue(opInst.getResult(0), result);
379       return success();
380     }
381     return success(result->getType()->isVoidTy());
382   }
383 
384   if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
385     llvm::Type *ty = moduleTranslation.convertType(lpOp.getType());
386     llvm::LandingPadInst *lpi =
387         builder.CreateLandingPad(ty, lpOp.getNumOperands());
388     lpi->setCleanup(lpOp.getCleanup());
389 
390     // Add clauses
391     for (llvm::Value *operand :
392          moduleTranslation.lookupValues(lpOp.getOperands())) {
393       // All operands should be constant - checked by verifier
394       if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
395         lpi->addClause(constOperand);
396     }
397     moduleTranslation.mapValue(lpOp.getResult(), lpi);
398     return success();
399   }
400 
401   // Emit branches.  We need to look up the remapped blocks and ignore the
402   // block arguments that were transformed into PHI nodes.
403   if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
404     llvm::BranchInst *branch =
405         builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
406     moduleTranslation.mapBranch(&opInst, branch);
407     moduleTranslation.setLoopMetadata(&opInst, branch);
408     return success();
409   }
410   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
411     llvm::BranchInst *branch = builder.CreateCondBr(
412         moduleTranslation.lookupValue(condbrOp.getOperand(0)),
413         moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
414         moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)));
415     moduleTranslation.mapBranch(&opInst, branch);
416     moduleTranslation.setLoopMetadata(&opInst, branch);
417     return success();
418   }
419   if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
420     llvm::SwitchInst *switchInst = builder.CreateSwitch(
421         moduleTranslation.lookupValue(switchOp.getValue()),
422         moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
423         switchOp.getCaseDestinations().size());
424 
425     // Handle switch with zero cases.
426     if (!switchOp.getCaseValues())
427       return success();
428 
429     auto *ty = llvm::cast<llvm::IntegerType>(
430         moduleTranslation.convertType(switchOp.getValue().getType()));
431     for (auto i :
432          llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
433                    switchOp.getCaseDestinations()))
434       switchInst->addCase(
435           llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
436           moduleTranslation.lookupBlock(std::get<1>(i)));
437 
438     moduleTranslation.mapBranch(&opInst, switchInst);
439     return success();
440   }
441 
442   // Emit addressof.  We need to look up the global value referenced by the
443   // operation and store it in the MLIR-to-LLVM value mapping.  This does not
444   // emit any LLVM instruction.
445   if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
446     LLVM::GlobalOp global =
447         addressOfOp.getGlobal(moduleTranslation.symbolTable());
448     LLVM::LLVMFuncOp function =
449         addressOfOp.getFunction(moduleTranslation.symbolTable());
450 
451     // The verifier should not have allowed this.
452     assert((global || function) &&
453            "referencing an undefined global or function");
454 
455     moduleTranslation.mapValue(
456         addressOfOp.getResult(),
457         global ? moduleTranslation.lookupGlobal(global)
458                : moduleTranslation.lookupFunction(function.getName()));
459     return success();
460   }
461 
462   return failure();
463 }
464 
465 namespace {
466 /// Implementation of the dialect interface that converts operations belonging
467 /// to the LLVM dialect to LLVM IR.
468 class LLVMDialectLLVMIRTranslationInterface
469     : public LLVMTranslationDialectInterface {
470 public:
471   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
472 
473   /// Translates the given operation to LLVM IR using the provided IR builder
474   /// and saving the state in `moduleTranslation`.
475   LogicalResult
476   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
477                    LLVM::ModuleTranslation &moduleTranslation) const final {
478     return convertOperationImpl(*op, builder, moduleTranslation);
479   }
480 };
481 } // namespace
482 
483 void mlir::registerLLVMDialectTranslation(DialectRegistry &registry) {
484   registry.insert<LLVM::LLVMDialect>();
485   registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
486     dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
487   });
488 }
489 
490 void mlir::registerLLVMDialectTranslation(MLIRContext &context) {
491   DialectRegistry registry;
492   registerLLVMDialectTranslation(registry);
493   context.appendDialectRegistry(registry);
494 }
495