xref: /llvm-project/mlir/lib/Dialect/Func/IR/FuncOps.cpp (revision 663e9cec9c96169aa4e72ab9b6bf08b2d6603093)
123aa5a74SRiver Riddle //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
223aa5a74SRiver Riddle //
323aa5a74SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
423aa5a74SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
523aa5a74SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
623aa5a74SRiver Riddle //
723aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
823aa5a74SRiver Riddle 
923aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1023aa5a74SRiver Riddle 
11b43c5049SJustin Fargnoli #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12513cdb82SJustin Fargnoli #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1323aa5a74SRiver Riddle #include "mlir/IR/BuiltinOps.h"
1423aa5a74SRiver Riddle #include "mlir/IR/BuiltinTypes.h"
154d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
1623aa5a74SRiver Riddle #include "mlir/IR/Matchers.h"
1723aa5a74SRiver Riddle #include "mlir/IR/OpImplementation.h"
1823aa5a74SRiver Riddle #include "mlir/IR/PatternMatch.h"
1923aa5a74SRiver Riddle #include "mlir/IR/TypeUtilities.h"
2023aa5a74SRiver Riddle #include "mlir/IR/Value.h"
2134a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionImplementation.h"
2223aa5a74SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
2323aa5a74SRiver Riddle #include "llvm/ADT/APFloat.h"
2436550692SRiver Riddle #include "llvm/ADT/MapVector.h"
2523aa5a74SRiver Riddle #include "llvm/ADT/STLExtras.h"
2623aa5a74SRiver Riddle #include "llvm/Support/FormatVariadic.h"
2723aa5a74SRiver Riddle #include "llvm/Support/raw_ostream.h"
2823aa5a74SRiver Riddle #include <numeric>
2923aa5a74SRiver Riddle 
3023aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
3123aa5a74SRiver Riddle 
3223aa5a74SRiver Riddle using namespace mlir;
3323aa5a74SRiver Riddle using namespace mlir::func;
3423aa5a74SRiver Riddle 
3523aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
3623aa5a74SRiver Riddle // FuncDialect
3723aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
3823aa5a74SRiver Riddle 
3923aa5a74SRiver Riddle void FuncDialect::initialize() {
4023aa5a74SRiver Riddle   addOperations<
4123aa5a74SRiver Riddle #define GET_OP_LIST
4223aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
4323aa5a74SRiver Riddle       >();
4435d55f28SJustin Fargnoli   declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
4535d55f28SJustin Fargnoli   declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
46513cdb82SJustin Fargnoli   declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
47513cdb82SJustin Fargnoli                             FuncOp, ReturnOp>();
4823aa5a74SRiver Riddle }
4923aa5a74SRiver Riddle 
5023aa5a74SRiver Riddle /// Materialize a single constant operation from a given attribute value with
5123aa5a74SRiver Riddle /// the desired resultant type.
5223aa5a74SRiver Riddle Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
5323aa5a74SRiver Riddle                                             Type type, Location loc) {
5423aa5a74SRiver Riddle   if (ConstantOp::isBuildableWith(value, type))
5523aa5a74SRiver Riddle     return builder.create<ConstantOp>(loc, type,
56c1fa60b4STres Popp                                       llvm::cast<FlatSymbolRefAttr>(value));
5723aa5a74SRiver Riddle   return nullptr;
5823aa5a74SRiver Riddle }
5923aa5a74SRiver Riddle 
6023aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
6123aa5a74SRiver Riddle // CallOp
6223aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
6323aa5a74SRiver Riddle 
6423aa5a74SRiver Riddle LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
6523aa5a74SRiver Riddle   // Check that the callee attribute was specified.
6623aa5a74SRiver Riddle   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
6723aa5a74SRiver Riddle   if (!fnAttr)
6823aa5a74SRiver Riddle     return emitOpError("requires a 'callee' symbol reference attribute");
6923aa5a74SRiver Riddle   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
7023aa5a74SRiver Riddle   if (!fn)
7123aa5a74SRiver Riddle     return emitOpError() << "'" << fnAttr.getValue()
7223aa5a74SRiver Riddle                          << "' does not reference a valid function";
7323aa5a74SRiver Riddle 
7423aa5a74SRiver Riddle   // Verify that the operand and result types match the callee.
754a3460a7SRiver Riddle   auto fnType = fn.getFunctionType();
7623aa5a74SRiver Riddle   if (fnType.getNumInputs() != getNumOperands())
7723aa5a74SRiver Riddle     return emitOpError("incorrect number of operands for callee");
7823aa5a74SRiver Riddle 
7923aa5a74SRiver Riddle   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
8023aa5a74SRiver Riddle     if (getOperand(i).getType() != fnType.getInput(i))
8123aa5a74SRiver Riddle       return emitOpError("operand type mismatch: expected operand type ")
8223aa5a74SRiver Riddle              << fnType.getInput(i) << ", but provided "
8323aa5a74SRiver Riddle              << getOperand(i).getType() << " for operand number " << i;
8423aa5a74SRiver Riddle 
8523aa5a74SRiver Riddle   if (fnType.getNumResults() != getNumResults())
8623aa5a74SRiver Riddle     return emitOpError("incorrect number of results for callee");
8723aa5a74SRiver Riddle 
8823aa5a74SRiver Riddle   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
8923aa5a74SRiver Riddle     if (getResult(i).getType() != fnType.getResult(i)) {
9023aa5a74SRiver Riddle       auto diag = emitOpError("result type mismatch at index ") << i;
9123aa5a74SRiver Riddle       diag.attachNote() << "      op result types: " << getResultTypes();
9223aa5a74SRiver Riddle       diag.attachNote() << "function result types: " << fnType.getResults();
9323aa5a74SRiver Riddle       return diag;
9423aa5a74SRiver Riddle     }
9523aa5a74SRiver Riddle 
9623aa5a74SRiver Riddle   return success();
9723aa5a74SRiver Riddle }
9823aa5a74SRiver Riddle 
9923aa5a74SRiver Riddle FunctionType CallOp::getCalleeType() {
10023aa5a74SRiver Riddle   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
10123aa5a74SRiver Riddle }
10223aa5a74SRiver Riddle 
10323aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
10423aa5a74SRiver Riddle // CallIndirectOp
10523aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
10623aa5a74SRiver Riddle 
10723aa5a74SRiver Riddle /// Fold indirect calls that have a constant function as the callee operand.
10823aa5a74SRiver Riddle LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
10923aa5a74SRiver Riddle                                            PatternRewriter &rewriter) {
11023aa5a74SRiver Riddle   // Check that the callee is a constant callee.
11123aa5a74SRiver Riddle   SymbolRefAttr calledFn;
11223aa5a74SRiver Riddle   if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
11323aa5a74SRiver Riddle     return failure();
11423aa5a74SRiver Riddle 
11523aa5a74SRiver Riddle   // Replace with a direct call.
11623aa5a74SRiver Riddle   rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
11723aa5a74SRiver Riddle                                       indirectCall.getResultTypes(),
11823aa5a74SRiver Riddle                                       indirectCall.getArgOperands());
11923aa5a74SRiver Riddle   return success();
12023aa5a74SRiver Riddle }
12123aa5a74SRiver Riddle 
12223aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
12323aa5a74SRiver Riddle // ConstantOp
12423aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
12523aa5a74SRiver Riddle 
126*663e9cecSArtem Kroviakov LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
12723aa5a74SRiver Riddle   StringRef fnName = getValue();
12823aa5a74SRiver Riddle   Type type = getType();
12923aa5a74SRiver Riddle 
13023aa5a74SRiver Riddle   // Try to find the referenced function.
131*663e9cecSArtem Kroviakov   auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
132*663e9cecSArtem Kroviakov       this->getOperation(), StringAttr::get(getContext(), fnName));
13323aa5a74SRiver Riddle   if (!fn)
13423aa5a74SRiver Riddle     return emitOpError() << "reference to undefined function '" << fnName
13523aa5a74SRiver Riddle                          << "'";
13623aa5a74SRiver Riddle 
13723aa5a74SRiver Riddle   // Check that the referenced function has the correct type.
1384a3460a7SRiver Riddle   if (fn.getFunctionType() != type)
13923aa5a74SRiver Riddle     return emitOpError("reference to function with mismatched type");
14023aa5a74SRiver Riddle 
14123aa5a74SRiver Riddle   return success();
14223aa5a74SRiver Riddle }
14323aa5a74SRiver Riddle 
1447df76121SMarkus Böck OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
14523aa5a74SRiver Riddle   return getValueAttr();
14623aa5a74SRiver Riddle }
14723aa5a74SRiver Riddle 
14823aa5a74SRiver Riddle void ConstantOp::getAsmResultNames(
14923aa5a74SRiver Riddle     function_ref<void(Value, StringRef)> setNameFn) {
15023aa5a74SRiver Riddle   setNameFn(getResult(), "f");
15123aa5a74SRiver Riddle }
15223aa5a74SRiver Riddle 
15323aa5a74SRiver Riddle bool ConstantOp::isBuildableWith(Attribute value, Type type) {
154c1fa60b4STres Popp   return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
15523aa5a74SRiver Riddle }
15623aa5a74SRiver Riddle 
15723aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
15836550692SRiver Riddle // FuncOp
15936550692SRiver Riddle //===----------------------------------------------------------------------===//
16036550692SRiver Riddle 
16136550692SRiver Riddle FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
16236550692SRiver Riddle                       ArrayRef<NamedAttribute> attrs) {
16336550692SRiver Riddle   OpBuilder builder(location->getContext());
16436550692SRiver Riddle   OperationState state(location, getOperationName());
16536550692SRiver Riddle   FuncOp::build(builder, state, name, type, attrs);
16636550692SRiver Riddle   return cast<FuncOp>(Operation::create(state));
16736550692SRiver Riddle }
16836550692SRiver Riddle FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
16936550692SRiver Riddle                       Operation::dialect_attr_range attrs) {
17036550692SRiver Riddle   SmallVector<NamedAttribute, 8> attrRef(attrs);
171984b800aSserge-sans-paille   return create(location, name, type, llvm::ArrayRef(attrRef));
17236550692SRiver Riddle }
17336550692SRiver Riddle FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
17436550692SRiver Riddle                       ArrayRef<NamedAttribute> attrs,
17536550692SRiver Riddle                       ArrayRef<DictionaryAttr> argAttrs) {
17636550692SRiver Riddle   FuncOp func = create(location, name, type, attrs);
17736550692SRiver Riddle   func.setAllArgAttrs(argAttrs);
17836550692SRiver Riddle   return func;
17936550692SRiver Riddle }
18036550692SRiver Riddle 
18136550692SRiver Riddle void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
18236550692SRiver Riddle                    FunctionType type, ArrayRef<NamedAttribute> attrs,
18336550692SRiver Riddle                    ArrayRef<DictionaryAttr> argAttrs) {
18436550692SRiver Riddle   state.addAttribute(SymbolTable::getSymbolAttrName(),
18536550692SRiver Riddle                      builder.getStringAttr(name));
18653406427SJeff Niu   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
18736550692SRiver Riddle   state.attributes.append(attrs.begin(), attrs.end());
18836550692SRiver Riddle   state.addRegion();
18936550692SRiver Riddle 
19036550692SRiver Riddle   if (argAttrs.empty())
19136550692SRiver Riddle     return;
19236550692SRiver Riddle   assert(type.getNumInputs() == argAttrs.size());
19353406427SJeff Niu   function_interface_impl::addArgAndResultAttrs(
19453406427SJeff Niu       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
19553406427SJeff Niu       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
19636550692SRiver Riddle }
19736550692SRiver Riddle 
19836550692SRiver Riddle ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
19936550692SRiver Riddle   auto buildFuncType =
20036550692SRiver Riddle       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
20136550692SRiver Riddle          function_interface_impl::VariadicFlag,
20236550692SRiver Riddle          std::string &) { return builder.getFunctionType(argTypes, results); };
20336550692SRiver Riddle 
20436550692SRiver Riddle   return function_interface_impl::parseFunctionOp(
20553406427SJeff Niu       parser, result, /*allowVariadic=*/false,
20653406427SJeff Niu       getFunctionTypeAttrName(result.name), buildFuncType,
20753406427SJeff Niu       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
20836550692SRiver Riddle }
20936550692SRiver Riddle 
21036550692SRiver Riddle void FuncOp::print(OpAsmPrinter &p) {
21153406427SJeff Niu   function_interface_impl::printFunctionOp(
21253406427SJeff Niu       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
21353406427SJeff Niu       getArgAttrsAttrName(), getResAttrsAttrName());
21436550692SRiver Riddle }
21536550692SRiver Riddle 
21636550692SRiver Riddle /// Clone the internal blocks from this function into dest and all attributes
21736550692SRiver Riddle /// from this function to dest.
2184d67b278SJeff Niu void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
21936550692SRiver Riddle   // Add the attributes of this function to dest.
22036550692SRiver Riddle   llvm::MapVector<StringAttr, Attribute> newAttrMap;
22136550692SRiver Riddle   for (const auto &attr : dest->getAttrs())
22236550692SRiver Riddle     newAttrMap.insert({attr.getName(), attr.getValue()});
22336550692SRiver Riddle   for (const auto &attr : (*this)->getAttrs())
22436550692SRiver Riddle     newAttrMap.insert({attr.getName(), attr.getValue()});
22536550692SRiver Riddle 
22636550692SRiver Riddle   auto newAttrs = llvm::to_vector(llvm::map_range(
22736550692SRiver Riddle       newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
22836550692SRiver Riddle         return NamedAttribute(attrPair.first, attrPair.second);
22936550692SRiver Riddle       }));
23036550692SRiver Riddle   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
23136550692SRiver Riddle 
23236550692SRiver Riddle   // Clone the body.
23336550692SRiver Riddle   getBody().cloneInto(&dest.getBody(), mapper);
23436550692SRiver Riddle }
23536550692SRiver Riddle 
23636550692SRiver Riddle /// Create a deep copy of this function and all of its blocks, remapping
23736550692SRiver Riddle /// any operands that use values outside of the function using the map that is
23836550692SRiver Riddle /// provided (leaving them alone if no entry is present). Replaces references
23936550692SRiver Riddle /// to cloned sub-values with the corresponding value that is copied, and adds
24036550692SRiver Riddle /// those mappings to the mapper.
2414d67b278SJeff Niu FuncOp FuncOp::clone(IRMapping &mapper) {
24236550692SRiver Riddle   // Create the new function.
24336550692SRiver Riddle   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
24436550692SRiver Riddle 
24536550692SRiver Riddle   // If the function has a body, then the user might be deleting arguments to
24636550692SRiver Riddle   // the function by specifying them in the mapper. If so, we don't add the
24736550692SRiver Riddle   // argument to the input type vector.
24836550692SRiver Riddle   if (!isExternal()) {
2494a3460a7SRiver Riddle     FunctionType oldType = getFunctionType();
25036550692SRiver Riddle 
25136550692SRiver Riddle     unsigned oldNumArgs = oldType.getNumInputs();
25236550692SRiver Riddle     SmallVector<Type, 4> newInputs;
25336550692SRiver Riddle     newInputs.reserve(oldNumArgs);
25436550692SRiver Riddle     for (unsigned i = 0; i != oldNumArgs; ++i)
25536550692SRiver Riddle       if (!mapper.contains(getArgument(i)))
25636550692SRiver Riddle         newInputs.push_back(oldType.getInput(i));
25736550692SRiver Riddle 
25836550692SRiver Riddle     /// If any of the arguments were dropped, update the type and drop any
25936550692SRiver Riddle     /// necessary argument attributes.
26036550692SRiver Riddle     if (newInputs.size() != oldNumArgs) {
26136550692SRiver Riddle       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
26236550692SRiver Riddle                                         oldType.getResults()));
26336550692SRiver Riddle 
26436550692SRiver Riddle       if (ArrayAttr argAttrs = getAllArgAttrs()) {
26536550692SRiver Riddle         SmallVector<Attribute> newArgAttrs;
26636550692SRiver Riddle         newArgAttrs.reserve(newInputs.size());
26736550692SRiver Riddle         for (unsigned i = 0; i != oldNumArgs; ++i)
26836550692SRiver Riddle           if (!mapper.contains(getArgument(i)))
26936550692SRiver Riddle             newArgAttrs.push_back(argAttrs[i]);
27036550692SRiver Riddle         newFunc.setAllArgAttrs(newArgAttrs);
27136550692SRiver Riddle       }
27236550692SRiver Riddle     }
27336550692SRiver Riddle   }
27436550692SRiver Riddle 
27536550692SRiver Riddle   /// Clone the current function into the new one and return it.
27636550692SRiver Riddle   cloneInto(newFunc, mapper);
27736550692SRiver Riddle   return newFunc;
27836550692SRiver Riddle }
27936550692SRiver Riddle FuncOp FuncOp::clone() {
2804d67b278SJeff Niu   IRMapping mapper;
28136550692SRiver Riddle   return clone(mapper);
28236550692SRiver Riddle }
28336550692SRiver Riddle 
28436550692SRiver Riddle //===----------------------------------------------------------------------===//
28523aa5a74SRiver Riddle // ReturnOp
28623aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
28723aa5a74SRiver Riddle 
28823aa5a74SRiver Riddle LogicalResult ReturnOp::verify() {
28923aa5a74SRiver Riddle   auto function = cast<FuncOp>((*this)->getParentOp());
29023aa5a74SRiver Riddle 
29123aa5a74SRiver Riddle   // The operand number and types must match the function signature.
2924a3460a7SRiver Riddle   const auto &results = function.getFunctionType().getResults();
29323aa5a74SRiver Riddle   if (getNumOperands() != results.size())
29423aa5a74SRiver Riddle     return emitOpError("has ")
29523aa5a74SRiver Riddle            << getNumOperands() << " operands, but enclosing function (@"
29623aa5a74SRiver Riddle            << function.getName() << ") returns " << results.size();
29723aa5a74SRiver Riddle 
29823aa5a74SRiver Riddle   for (unsigned i = 0, e = results.size(); i != e; ++i)
29923aa5a74SRiver Riddle     if (getOperand(i).getType() != results[i])
30023aa5a74SRiver Riddle       return emitError() << "type of return operand " << i << " ("
30123aa5a74SRiver Riddle                          << getOperand(i).getType()
30223aa5a74SRiver Riddle                          << ") doesn't match function result type ("
30323aa5a74SRiver Riddle                          << results[i] << ")"
30423aa5a74SRiver Riddle                          << " in function @" << function.getName();
30523aa5a74SRiver Riddle 
30623aa5a74SRiver Riddle   return success();
30723aa5a74SRiver Riddle }
30823aa5a74SRiver Riddle 
30923aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
31023aa5a74SRiver Riddle // TableGen'd op method definitions
31123aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
31223aa5a74SRiver Riddle 
31323aa5a74SRiver Riddle #define GET_OP_CLASSES
31423aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
315