xref: /llvm-project/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (revision eb6c4197d5263ed2e086925b2b2f032a19442d2b)
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
15 
16 #include "mlir/Analysis/DataLayoutAnalysis.h"
17 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
18 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
22 #include "mlir/Conversion/LLVMCommon/Pattern.h"
23 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
28 #include "mlir/Dialect/Utils/StaticValueUtils.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinAttributeInterfaces.h"
32 #include "mlir/IR/BuiltinAttributes.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/IRMapping.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/SymbolTable.h"
37 #include "mlir/IR/TypeUtilities.h"
38 #include "mlir/Transforms/DialectConversion.h"
39 #include "mlir/Transforms/Passes.h"
40 #include "llvm/ADT/SmallVector.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Type.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include <algorithm>
49 #include <functional>
50 #include <optional>
51 
52 namespace mlir {
53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55 #include "mlir/Conversion/Passes.h.inc"
56 } // namespace mlir
57 
58 using namespace mlir;
59 
60 #define PASS_NAME "convert-func-to-llvm"
61 
62 static constexpr StringRef varargsAttrName = "func.varargs";
63 static constexpr StringRef linkageAttrName = "llvm.linkage";
64 static constexpr StringRef barePtrAttrName = "llvm.bareptr";
65 
66 /// Return `true` if the `op` should use bare pointer calling convention.
67 static bool shouldUseBarePtrCallConv(Operation *op,
68                                      const LLVMTypeConverter *typeConverter) {
69   return (op && op->hasAttr(barePtrAttrName)) ||
70          typeConverter->getOptions().useBarePtrCallConv;
71 }
72 
73 /// Only retain those attributes that are not constructed by
74 /// `LLVMFuncOp::build`.
75 static void filterFuncAttributes(FunctionOpInterface func,
76                                  SmallVectorImpl<NamedAttribute> &result) {
77   for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
78     if (attr.getName() == linkageAttrName ||
79         attr.getName() == varargsAttrName ||
80         attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
81       continue;
82     result.push_back(attr);
83   }
84 }
85 
86 /// Propagate argument/results attributes.
87 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
88                                  FunctionOpInterface funcOp,
89                                  LLVM::LLVMFuncOp wrapperFuncOp) {
90   auto argAttrs = funcOp.getAllArgAttrs();
91   if (!resultStructType) {
92     if (auto resAttrs = funcOp.getAllResultAttrs())
93       wrapperFuncOp.setAllResultAttrs(resAttrs);
94     if (argAttrs)
95       wrapperFuncOp.setAllArgAttrs(argAttrs);
96   } else {
97     SmallVector<Attribute> argAttributes;
98     // Only modify the argument and result attributes when the result is now
99     // an argument.
100     if (argAttrs) {
101       argAttributes.push_back(builder.getDictionaryAttr({}));
102       argAttributes.append(argAttrs.begin(), argAttrs.end());
103       wrapperFuncOp.setAllArgAttrs(argAttributes);
104     }
105   }
106   cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
107       .setVisibility(funcOp.getVisibility());
108 }
109 
110 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
111 /// arguments instead of unpacked arguments. This function can be called from C
112 /// by passing a pointer to a C struct corresponding to a memref descriptor.
113 /// Similarly, returned memrefs are passed via pointers to a C struct that is
114 /// passed as additional argument.
115 /// Internally, the auxiliary function unpacks the descriptor into individual
116 /// components and forwards them to `newFuncOp` and forwards the results to
117 /// the extra arguments.
118 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
119                                    const LLVMTypeConverter &typeConverter,
120                                    FunctionOpInterface funcOp,
121                                    LLVM::LLVMFuncOp newFuncOp) {
122   auto type = cast<FunctionType>(funcOp.getFunctionType());
123   auto [wrapperFuncType, resultStructType] =
124       typeConverter.convertFunctionTypeCWrapper(type);
125 
126   SmallVector<NamedAttribute> attributes;
127   filterFuncAttributes(funcOp, attributes);
128 
129   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
130       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
131       wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
132       /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
133   propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
134 
135   OpBuilder::InsertionGuard guard(rewriter);
136   rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
137 
138   SmallVector<Value, 8> args;
139   size_t argOffset = resultStructType ? 1 : 0;
140   for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
141     Value arg = wrapperFuncOp.getArgument(index + argOffset);
142     if (auto memrefType = dyn_cast<MemRefType>(argType)) {
143       Value loaded = rewriter.create<LLVM::LoadOp>(
144           loc, typeConverter.convertType(memrefType), arg);
145       MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
146       continue;
147     }
148     if (isa<UnrankedMemRefType>(argType)) {
149       Value loaded = rewriter.create<LLVM::LoadOp>(
150           loc, typeConverter.convertType(argType), arg);
151       UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
152       continue;
153     }
154 
155     args.push_back(arg);
156   }
157 
158   auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
159 
160   if (resultStructType) {
161     rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
162                                    wrapperFuncOp.getArgument(0));
163     rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
164   } else {
165     rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
166   }
167 }
168 
169 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
170 /// arguments instead of unpacked arguments. Creates a body for the (external)
171 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
172 /// individual arguments into this descriptor and passes a pointer to it into
173 /// the auxiliary function. If the result of the function cannot be directly
174 /// returned, we write it to a special first argument that provides a pointer
175 /// to a corresponding struct. This auxiliary external function is now
176 /// compatible with functions defined in C using pointers to C structs
177 /// corresponding to a memref descriptor.
178 static void wrapExternalFunction(OpBuilder &builder, Location loc,
179                                  const LLVMTypeConverter &typeConverter,
180                                  FunctionOpInterface funcOp,
181                                  LLVM::LLVMFuncOp newFuncOp) {
182   OpBuilder::InsertionGuard guard(builder);
183 
184   auto [wrapperType, resultStructType] =
185       typeConverter.convertFunctionTypeCWrapper(
186           cast<FunctionType>(funcOp.getFunctionType()));
187   // This conversion can only fail if it could not convert one of the argument
188   // types. But since it has been applied to a non-wrapper function before, it
189   // should have failed earlier and not reach this point at all.
190   assert(wrapperType && "unexpected type conversion failure");
191 
192   SmallVector<NamedAttribute, 4> attributes;
193   filterFuncAttributes(funcOp, attributes);
194 
195   // Create the auxiliary function.
196   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
197       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
198       wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
199       /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
200   propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
201 
202   // The wrapper that we synthetize here should only be visible in this module.
203   newFuncOp.setLinkage(LLVM::Linkage::Private);
204   builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
205 
206   // Get a ValueRange containing arguments.
207   FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
208   SmallVector<Value, 8> args;
209   args.reserve(type.getNumInputs());
210   ValueRange wrapperArgsRange(newFuncOp.getArguments());
211 
212   if (resultStructType) {
213     // Allocate the struct on the stack and pass the pointer.
214     Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
215     Value one = builder.create<LLVM::ConstantOp>(
216         loc, typeConverter.convertType(builder.getIndexType()),
217         builder.getIntegerAttr(builder.getIndexType(), 1));
218     Value result =
219         builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
220     args.push_back(result);
221   }
222 
223   // Iterate over the inputs of the original function and pack values into
224   // memref descriptors if the original type is a memref.
225   for (Type input : type.getInputs()) {
226     Value arg;
227     int numToDrop = 1;
228     auto memRefType = dyn_cast<MemRefType>(input);
229     auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
230     if (memRefType || unrankedMemRefType) {
231       numToDrop = memRefType
232                       ? MemRefDescriptor::getNumUnpackedValues(memRefType)
233                       : UnrankedMemRefDescriptor::getNumUnpackedValues();
234       Value packed =
235           memRefType
236               ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
237                                        wrapperArgsRange.take_front(numToDrop))
238               : UnrankedMemRefDescriptor::pack(
239                     builder, loc, typeConverter, unrankedMemRefType,
240                     wrapperArgsRange.take_front(numToDrop));
241 
242       auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
243       Value one = builder.create<LLVM::ConstantOp>(
244           loc, typeConverter.convertType(builder.getIndexType()),
245           builder.getIntegerAttr(builder.getIndexType(), 1));
246       Value allocated = builder.create<LLVM::AllocaOp>(
247           loc, ptrTy, packed.getType(), one, /*alignment=*/0);
248       builder.create<LLVM::StoreOp>(loc, packed, allocated);
249       arg = allocated;
250     } else {
251       arg = wrapperArgsRange[0];
252     }
253 
254     args.push_back(arg);
255     wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
256   }
257   assert(wrapperArgsRange.empty() && "did not map some of the arguments");
258 
259   auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
260 
261   if (resultStructType) {
262     Value result =
263         builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
264     builder.create<LLVM::ReturnOp>(loc, result);
265   } else {
266     builder.create<LLVM::ReturnOp>(loc, call.getResults());
267   }
268 }
269 
270 /// Inserts `llvm.load` ops in the function body to restore the expected pointee
271 /// value from `llvm.byval`/`llvm.byref` function arguments that were converted
272 /// to LLVM pointer types.
273 static void restoreByValRefArgumentType(
274     ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
275     ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
276     ArrayRef<BlockArgument> oldBlockArgs, LLVM::LLVMFuncOp funcOp) {
277   // Nothing to do for function declarations.
278   if (funcOp.isExternal())
279     return;
280 
281   ConversionPatternRewriter::InsertionGuard guard(rewriter);
282   rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
283 
284   for (const auto &[arg, oldArg, byValRefAttr] :
285        llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) {
286     // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
287     if (!byValRefAttr)
288       continue;
289 
290     // Insert load to retrieve the actual argument passed by value/reference.
291     assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
292            "Expected LLVM pointer type for argument with "
293            "`llvm.byval`/`llvm.byref` attribute");
294     Type resTy = typeConverter.convertType(
295         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
296 
297     auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
298     rewriter.replaceUsesOfBlockArgument(oldArg, valueArg);
299   }
300 }
301 
302 FailureOr<LLVM::LLVMFuncOp>
303 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
304                                 ConversionPatternRewriter &rewriter,
305                                 const LLVMTypeConverter &converter) {
306   // Check the funcOp has `FunctionType`.
307   auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
308   if (!funcTy)
309     return rewriter.notifyMatchFailure(
310         funcOp, "Only support FunctionOpInterface with FunctionType");
311 
312   // Keep track of the entry block arguments. They will be needed later.
313   SmallVector<BlockArgument> oldBlockArgs =
314       llvm::to_vector(funcOp.getArguments());
315 
316   // Convert the original function arguments. They are converted using the
317   // LLVMTypeConverter provided to this legalization pattern.
318   auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
319   // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
320   // overriden with an LLVM pointer type for later processing.
321   SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
322   TypeConverter::SignatureConversion result(funcOp.getNumArguments());
323   auto llvmType = converter.convertFunctionSignature(
324       funcOp, varargsAttr && varargsAttr.getValue(),
325       shouldUseBarePtrCallConv(funcOp, &converter), result,
326       byValRefNonPtrAttrs);
327   if (!llvmType)
328     return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
329 
330   // Create an LLVM function, use external linkage by default until MLIR
331   // functions have linkage.
332   LLVM::Linkage linkage = LLVM::Linkage::External;
333   if (funcOp->hasAttr(linkageAttrName)) {
334     auto attr =
335         dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
336     if (!attr) {
337       funcOp->emitError() << "Contains " << linkageAttrName
338                           << " attribute not of type LLVM::LinkageAttr";
339       return rewriter.notifyMatchFailure(
340           funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
341     }
342     linkage = attr.getLinkage();
343   }
344 
345   SmallVector<NamedAttribute, 4> attributes;
346   filterFuncAttributes(funcOp, attributes);
347   auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
348       funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
349       /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
350       attributes);
351   cast<FunctionOpInterface>(newFuncOp.getOperation())
352       .setVisibility(funcOp.getVisibility());
353 
354   // Create a memory effect attribute corresponding to readnone.
355   StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
356   if (funcOp->hasAttr(readnoneAttrName)) {
357     auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
358     if (!attr) {
359       funcOp->emitError() << "Contains " << readnoneAttrName
360                           << " attribute not of type UnitAttr";
361       return rewriter.notifyMatchFailure(
362           funcOp, "Contains readnone attribute not of type UnitAttr");
363     }
364     auto memoryAttr = LLVM::MemoryEffectsAttr::get(
365         rewriter.getContext(),
366         {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
367          LLVM::ModRefInfo::NoModRef});
368     newFuncOp.setMemoryEffectsAttr(memoryAttr);
369   }
370 
371   // Propagate argument/result attributes to all converted arguments/result
372   // obtained after converting a given original argument/result.
373   if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
374     assert(!resAttrDicts.empty() && "expected array to be non-empty");
375     if (funcOp.getNumResults() == 1)
376       newFuncOp.setAllResultAttrs(resAttrDicts);
377   }
378   if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
379     SmallVector<Attribute> newArgAttrs(
380         cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
381     for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
382       // Some LLVM IR attribute have a type attached to them. During FuncOp ->
383       // LLVMFuncOp conversion these types may have changed. Account for that
384       // change by converting attributes' types as well.
385       SmallVector<NamedAttribute, 4> convertedAttrs;
386       auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
387       convertedAttrs.reserve(attrsDict.size());
388       for (const NamedAttribute &attr : attrsDict) {
389         const auto convert = [&](const NamedAttribute &attr) {
390           return TypeAttr::get(converter.convertType(
391               cast<TypeAttr>(attr.getValue()).getValue()));
392         };
393         if (attr.getName().getValue() ==
394             LLVM::LLVMDialect::getByValAttrName()) {
395           convertedAttrs.push_back(rewriter.getNamedAttr(
396               LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
397         } else if (attr.getName().getValue() ==
398                    LLVM::LLVMDialect::getByRefAttrName()) {
399           convertedAttrs.push_back(rewriter.getNamedAttr(
400               LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
401         } else if (attr.getName().getValue() ==
402                    LLVM::LLVMDialect::getStructRetAttrName()) {
403           convertedAttrs.push_back(rewriter.getNamedAttr(
404               LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
405         } else if (attr.getName().getValue() ==
406                    LLVM::LLVMDialect::getInAllocaAttrName()) {
407           convertedAttrs.push_back(rewriter.getNamedAttr(
408               LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
409         } else {
410           convertedAttrs.push_back(attr);
411         }
412       }
413       auto mapping = result.getInputMapping(i);
414       assert(mapping && "unexpected deletion of function argument");
415       // Only attach the new argument attributes if there is a one-to-one
416       // mapping from old to new types. Otherwise, attributes might be
417       // attached to types that they do not support.
418       if (mapping->size == 1) {
419         newArgAttrs[mapping->inputNo] =
420             DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
421         continue;
422       }
423       // TODO: Implement custom handling for types that expand to multiple
424       // function arguments.
425       for (size_t j = 0; j < mapping->size; ++j)
426         newArgAttrs[mapping->inputNo + j] =
427             DictionaryAttr::get(rewriter.getContext(), {});
428     }
429     if (!newArgAttrs.empty())
430       newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
431   }
432 
433   rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
434                               newFuncOp.end());
435   // Convert just the entry block. The remaining unstructured control flow is
436   // converted by ControlFlowToLLVM.
437   if (!newFuncOp.getBody().empty())
438     rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
439                                       &converter);
440 
441   // Fix the type mismatch between the materialized `llvm.ptr` and the expected
442   // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
443   // function arguments.
444   restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
445                               oldBlockArgs, newFuncOp);
446 
447   if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
448     if (funcOp->getAttrOfType<UnitAttr>(
449             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
450       if (newFuncOp.isVarArg())
451         return funcOp.emitError("C interface for variadic functions is not "
452                                 "supported yet.");
453 
454       if (newFuncOp.isExternal())
455         wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
456                              newFuncOp);
457       else
458         wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
459                                newFuncOp);
460     }
461   }
462 
463   return newFuncOp;
464 }
465 
466 namespace {
467 
468 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
469 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
470 /// information.
471 struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
472   FuncOpConversion(const LLVMTypeConverter &converter)
473       : ConvertOpToLLVMPattern(converter) {}
474 
475   LogicalResult
476   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
477                   ConversionPatternRewriter &rewriter) const override {
478     FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
479         cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
480         *getTypeConverter());
481     if (failed(newFuncOp))
482       return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
483 
484     rewriter.eraseOp(funcOp);
485     return success();
486   }
487 };
488 
489 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
490   using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
491 
492   LogicalResult
493   matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
494                   ConversionPatternRewriter &rewriter) const override {
495     auto type = typeConverter->convertType(op.getResult().getType());
496     if (!type || !LLVM::isCompatibleType(type))
497       return rewriter.notifyMatchFailure(op, "failed to convert result type");
498 
499     auto newOp =
500         rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
501     for (const NamedAttribute &attr : op->getAttrs()) {
502       if (attr.getName().strref() == "value")
503         continue;
504       newOp->setAttr(attr.getName(), attr.getValue());
505     }
506     rewriter.replaceOp(op, newOp->getResults());
507     return success();
508   }
509 };
510 
511 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
512 // passes the pointer to the MemRef across function boundaries.
513 template <typename CallOpType>
514 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
515   using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
516   using Super = CallOpInterfaceLowering<CallOpType>;
517   using Base = ConvertOpToLLVMPattern<CallOpType>;
518 
519   LogicalResult matchAndRewriteImpl(CallOpType callOp,
520                                     typename CallOpType::Adaptor adaptor,
521                                     ConversionPatternRewriter &rewriter,
522                                     bool useBarePtrCallConv = false) const {
523     // Pack the result types into a struct.
524     Type packedResult = nullptr;
525     unsigned numResults = callOp.getNumResults();
526     auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
527 
528     if (numResults != 0) {
529       if (!(packedResult = this->getTypeConverter()->packFunctionResults(
530                 resultTypes, useBarePtrCallConv)))
531         return failure();
532     }
533 
534     if (useBarePtrCallConv) {
535       for (auto it : callOp->getOperands()) {
536         Type operandType = it.getType();
537         if (isa<UnrankedMemRefType>(operandType)) {
538           // Unranked memref is not supported in the bare pointer calling
539           // convention.
540           return failure();
541         }
542       }
543     }
544     auto promoted = this->getTypeConverter()->promoteOperands(
545         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
546         adaptor.getOperands(), rewriter, useBarePtrCallConv);
547     auto newOp = rewriter.create<LLVM::CallOp>(
548         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
549         promoted, callOp->getAttrs());
550 
551     newOp.getProperties().operandSegmentSizes = {
552         static_cast<int32_t>(promoted.size()), 0};
553     newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
554 
555     SmallVector<Value, 4> results;
556     if (numResults < 2) {
557       // If < 2 results, packing did not do anything and we can just return.
558       results.append(newOp.result_begin(), newOp.result_end());
559     } else {
560       // Otherwise, it had been converted to an operation producing a structure.
561       // Extract individual results from the structure and return them as list.
562       results.reserve(numResults);
563       for (unsigned i = 0; i < numResults; ++i) {
564         results.push_back(rewriter.create<LLVM::ExtractValueOp>(
565             callOp.getLoc(), newOp->getResult(0), i));
566       }
567     }
568 
569     if (useBarePtrCallConv) {
570       // For the bare-ptr calling convention, promote memref results to
571       // descriptors.
572       assert(results.size() == resultTypes.size() &&
573              "The number of arguments and types doesn't match");
574       this->getTypeConverter()->promoteBarePtrsToDescriptors(
575           rewriter, callOp.getLoc(), resultTypes, results);
576     } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
577                                                     resultTypes, results,
578                                                     /*toDynamic=*/false))) {
579       return failure();
580     }
581 
582     rewriter.replaceOp(callOp, results);
583     return success();
584   }
585 };
586 
587 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
588 public:
589   CallOpLowering(const LLVMTypeConverter &typeConverter,
590                  // Can be nullptr.
591                  const SymbolTable *symbolTable, PatternBenefit benefit = 1)
592       : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
593         symbolTable(symbolTable) {}
594 
595   LogicalResult
596   matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
597                   ConversionPatternRewriter &rewriter) const override {
598     bool useBarePtrCallConv = false;
599     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
600       useBarePtrCallConv = true;
601     } else if (symbolTable != nullptr) {
602       // Fast lookup.
603       Operation *callee =
604           symbolTable->lookup(callOp.getCalleeAttr().getValue());
605       useBarePtrCallConv =
606           callee != nullptr && callee->hasAttr(barePtrAttrName);
607     } else {
608       // Warning: This is a linear lookup.
609       Operation *callee =
610           SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
611       useBarePtrCallConv =
612           callee != nullptr && callee->hasAttr(barePtrAttrName);
613     }
614     return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
615   }
616 
617 private:
618   const SymbolTable *symbolTable = nullptr;
619 };
620 
621 struct CallIndirectOpLowering
622     : public CallOpInterfaceLowering<func::CallIndirectOp> {
623   using Super::Super;
624 
625   LogicalResult
626   matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
627                   ConversionPatternRewriter &rewriter) const override {
628     return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
629   }
630 };
631 
632 struct UnrealizedConversionCastOpLowering
633     : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
634   using ConvertOpToLLVMPattern<
635       UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
636 
637   LogicalResult
638   matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
639                   ConversionPatternRewriter &rewriter) const override {
640     SmallVector<Type> convertedTypes;
641     if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
642                                               convertedTypes)) &&
643         convertedTypes == adaptor.getInputs().getTypes()) {
644       rewriter.replaceOp(op, adaptor.getInputs());
645       return success();
646     }
647 
648     convertedTypes.clear();
649     if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
650                                               convertedTypes)) &&
651         convertedTypes == op.getOutputs().getType()) {
652       rewriter.replaceOp(op, adaptor.getInputs());
653       return success();
654     }
655     return failure();
656   }
657 };
658 
659 // Special lowering pattern for `ReturnOps`.  Unlike all other operations,
660 // `ReturnOp` interacts with the function signature and must have as many
661 // operands as the function has return values.  Because in LLVM IR, functions
662 // can only return 0 or 1 value, we pack multiple values into a structure type.
663 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
664 // necessary before returning it
665 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
666   using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
667 
668   LogicalResult
669   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
670                   ConversionPatternRewriter &rewriter) const override {
671     Location loc = op.getLoc();
672     unsigned numArguments = op.getNumOperands();
673     SmallVector<Value, 4> updatedOperands;
674 
675     auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
676     bool useBarePtrCallConv =
677         shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
678     if (useBarePtrCallConv) {
679       // For the bare-ptr calling convention, extract the aligned pointer to
680       // be returned from the memref descriptor.
681       for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
682         Type oldTy = std::get<0>(it).getType();
683         Value newOperand = std::get<1>(it);
684         if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
685                                           cast<BaseMemRefType>(oldTy))) {
686           MemRefDescriptor memrefDesc(newOperand);
687           newOperand = memrefDesc.allocatedPtr(rewriter, loc);
688         } else if (isa<UnrankedMemRefType>(oldTy)) {
689           // Unranked memref is not supported in the bare pointer calling
690           // convention.
691           return failure();
692         }
693         updatedOperands.push_back(newOperand);
694       }
695     } else {
696       updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
697       (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
698                                     updatedOperands,
699                                     /*toDynamic=*/true);
700     }
701 
702     // If ReturnOp has 0 or 1 operand, create it and return immediately.
703     if (numArguments <= 1) {
704       rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
705           op, TypeRange(), updatedOperands, op->getAttrs());
706       return success();
707     }
708 
709     // Otherwise, we need to pack the arguments into an LLVM struct type before
710     // returning.
711     auto packedType = getTypeConverter()->packFunctionResults(
712         op.getOperandTypes(), useBarePtrCallConv);
713     if (!packedType) {
714       return rewriter.notifyMatchFailure(op, "could not convert result types");
715     }
716 
717     Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
718     for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
719       packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
720     }
721     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
722                                                 op->getAttrs());
723     return success();
724   }
725 };
726 } // namespace
727 
/// Adds only the `FuncOpConversion` pattern, constructed from `converter`, to
/// `patterns`. None of the other lowering patterns are registered here.
728 void mlir::populateFuncToLLVMFuncOpConversionPattern(
729     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
730   patterns.add<FuncOpConversion>(converter);
731 }
732 
733 void mlir::populateFuncToLLVMConversionPatterns(
734     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
735     const SymbolTable *symbolTable) {
736   populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
737   patterns.add<CallIndirectOpLowering>(converter);
738   patterns.add<CallOpLowering>(converter, symbolTable);
739   patterns.add<ConstantOpLowering>(converter);
740   patterns.add<ReturnOpLowering>(converter);
741 }
742 
743 namespace {
744 /// A pass converting Func operations into the LLVM IR dialect.
745 struct ConvertFuncToLLVMPass
746     : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
747   using Base::Base;
748 
749   /// Run the dialect converter on the module.
750   void runOnOperation() override {
751     ModuleOp m = getOperation();
752     StringRef dataLayout;
753     auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
754         m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
755     if (dataLayoutAttr)
756       dataLayout = dataLayoutAttr.getValue();
757 
758     if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
759             dataLayout, [this](const Twine &message) {
760               getOperation().emitError() << message.str();
761             }))) {
762       signalPassFailure();
763       return;
764     }
765 
766     const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
767 
768     LowerToLLVMOptions options(&getContext(),
769                                dataLayoutAnalysis.getAtOrAbove(m));
770     options.useBarePtrCallConv = useBarePtrCallConv;
771     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
772       options.overrideIndexBitwidth(indexBitwidth);
773     options.dataLayout = llvm::DataLayout(dataLayout);
774 
775     LLVMTypeConverter typeConverter(&getContext(), options,
776                                     &dataLayoutAnalysis);
777 
778     std::optional<SymbolTable> optSymbolTable = std::nullopt;
779     const SymbolTable *symbolTable = nullptr;
780     if (!options.useBarePtrCallConv) {
781       optSymbolTable.emplace(m);
782       symbolTable = &optSymbolTable.value();
783     }
784 
785     RewritePatternSet patterns(&getContext());
786     populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
787 
788     LLVMConversionTarget target(getContext());
789     if (failed(applyPartialConversion(m, target, std::move(patterns))))
790       signalPassFailure();
791   }
792 };
793 
794 struct SetLLVMModuleDataLayoutPass
795     : public impl::SetLLVMModuleDataLayoutPassBase<
796           SetLLVMModuleDataLayoutPass> {
797   using Base::Base;
798 
799   /// Run the dialect converter on the module.
800   void runOnOperation() override {
801     if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
802             this->dataLayout, [this](const Twine &message) {
803               getOperation().emitError() << message.str();
804             }))) {
805       signalPassFailure();
806       return;
807     }
808     ModuleOp m = getOperation();
809     m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
810                StringAttr::get(m.getContext(), this->dataLayout));
811   }
812 };
813 } // namespace
814 
815 //===----------------------------------------------------------------------===//
816 // ConvertToLLVMPatternInterface implementation
817 //===----------------------------------------------------------------------===//
818 
819 namespace {
820 /// Implement the interface to convert Func to LLVM.
821 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
822   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
823   /// Hook for derived dialect interface to provide conversion patterns
824   /// and mark dialect legal for the conversion target.
  ///
  /// This implementation only adds the Func-to-LLVM patterns built from
  /// `typeConverter` to `patterns`; `target` is not modified here, and no
  /// symbol table argument is passed to the populate function.
825   void populateConvertToLLVMConversionPatterns(
826       ConversionTarget &target, LLVMTypeConverter &typeConverter,
827       RewritePatternSet &patterns) const final {
828     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
829   }
830 };
831 } // namespace
832 
/// Registers an extension on `registry` whose callback attaches
/// `FuncToLLVMDialectInterface` to the Func dialect.
833 void mlir::registerConvertFuncToLLVMInterface(DialectRegistry &registry) {
  // The unary `+` converts the captureless lambda to a plain function pointer
  // before it is handed to `addExtension`.
834   registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
835     dialect->addInterfaces<FuncToLLVMDialectInterface>();
836   });
837 }
838