xref: /llvm-project/mlir/lib/Dialect/Func/IR/FuncOps.cpp (revision d0e6fd99aa95ff61372ea328e9f89da2ee39c49c)
1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/IR/Value.h"
19 #include "mlir/Interfaces/FunctionImplementation.h"
20 #include "mlir/Support/MathExtras.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 #include "llvm/ADT/APFloat.h"
23 #include "llvm/ADT/MapVector.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include <numeric>
28 
29 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::func;
33 
34 //===----------------------------------------------------------------------===//
35 // FuncDialect
36 //===----------------------------------------------------------------------===//
37 
/// Register all operations of the func dialect and record that an
/// implementation of DialectInlinerInterface is promised for this dialect
/// (presumably attached elsewhere via a separately registered extension —
/// confirm against the dialect's registration code).
void FuncDialect::initialize() {
  // The operation list is spliced in from FuncOps.cpp.inc; GET_OP_LIST makes
  // that file expand to a comma-separated list of op classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
      >();
  declarePromisedInterface<FuncDialect, DialectInlinerInterface>();
}
45 
46 /// Materialize a single constant operation from a given attribute value with
47 /// the desired resultant type.
48 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
49                                             Type type, Location loc) {
50   if (ConstantOp::isBuildableWith(value, type))
51     return builder.create<ConstantOp>(loc, type,
52                                       llvm::cast<FlatSymbolRefAttr>(value));
53   return nullptr;
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // CallOp
58 //===----------------------------------------------------------------------===//
59 
60 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
61   // Check that the callee attribute was specified.
62   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
63   if (!fnAttr)
64     return emitOpError("requires a 'callee' symbol reference attribute");
65   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
66   if (!fn)
67     return emitOpError() << "'" << fnAttr.getValue()
68                          << "' does not reference a valid function";
69 
70   // Verify that the operand and result types match the callee.
71   auto fnType = fn.getFunctionType();
72   if (fnType.getNumInputs() != getNumOperands())
73     return emitOpError("incorrect number of operands for callee");
74 
75   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
76     if (getOperand(i).getType() != fnType.getInput(i))
77       return emitOpError("operand type mismatch: expected operand type ")
78              << fnType.getInput(i) << ", but provided "
79              << getOperand(i).getType() << " for operand number " << i;
80 
81   if (fnType.getNumResults() != getNumResults())
82     return emitOpError("incorrect number of results for callee");
83 
84   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
85     if (getResult(i).getType() != fnType.getResult(i)) {
86       auto diag = emitOpError("result type mismatch at index ") << i;
87       diag.attachNote() << "      op result types: " << getResultTypes();
88       diag.attachNote() << "function result types: " << fnType.getResults();
89       return diag;
90     }
91 
92   return success();
93 }
94 
95 FunctionType CallOp::getCalleeType() {
96   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // CallIndirectOp
101 //===----------------------------------------------------------------------===//
102 
103 /// Fold indirect calls that have a constant function as the callee operand.
104 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
105                                            PatternRewriter &rewriter) {
106   // Check that the callee is a constant callee.
107   SymbolRefAttr calledFn;
108   if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
109     return failure();
110 
111   // Replace with a direct call.
112   rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
113                                       indirectCall.getResultTypes(),
114                                       indirectCall.getArgOperands());
115   return success();
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // ConstantOp
120 //===----------------------------------------------------------------------===//
121 
122 LogicalResult ConstantOp::verify() {
123   StringRef fnName = getValue();
124   Type type = getType();
125 
126   // Try to find the referenced function.
127   auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
128   if (!fn)
129     return emitOpError() << "reference to undefined function '" << fnName
130                          << "'";
131 
132   // Check that the referenced function has the correct type.
133   if (fn.getFunctionType() != type)
134     return emitOpError("reference to function with mismatched type");
135 
136   return success();
137 }
138 
139 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
140   return getValueAttr();
141 }
142 
143 void ConstantOp::getAsmResultNames(
144     function_ref<void(Value, StringRef)> setNameFn) {
145   setNameFn(getResult(), "f");
146 }
147 
148 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
149   return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // FuncOp
154 //===----------------------------------------------------------------------===//
155 
156 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
157                       ArrayRef<NamedAttribute> attrs) {
158   OpBuilder builder(location->getContext());
159   OperationState state(location, getOperationName());
160   FuncOp::build(builder, state, name, type, attrs);
161   return cast<FuncOp>(Operation::create(state));
162 }
163 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
164                       Operation::dialect_attr_range attrs) {
165   SmallVector<NamedAttribute, 8> attrRef(attrs);
166   return create(location, name, type, llvm::ArrayRef(attrRef));
167 }
168 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
169                       ArrayRef<NamedAttribute> attrs,
170                       ArrayRef<DictionaryAttr> argAttrs) {
171   FuncOp func = create(location, name, type, attrs);
172   func.setAllArgAttrs(argAttrs);
173   return func;
174 }
175 
176 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
177                    FunctionType type, ArrayRef<NamedAttribute> attrs,
178                    ArrayRef<DictionaryAttr> argAttrs) {
179   state.addAttribute(SymbolTable::getSymbolAttrName(),
180                      builder.getStringAttr(name));
181   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
182   state.attributes.append(attrs.begin(), attrs.end());
183   state.addRegion();
184 
185   if (argAttrs.empty())
186     return;
187   assert(type.getNumInputs() == argAttrs.size());
188   function_interface_impl::addArgAndResultAttrs(
189       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
190       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
191 }
192 
193 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
194   auto buildFuncType =
195       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
196          function_interface_impl::VariadicFlag,
197          std::string &) { return builder.getFunctionType(argTypes, results); };
198 
199   return function_interface_impl::parseFunctionOp(
200       parser, result, /*allowVariadic=*/false,
201       getFunctionTypeAttrName(result.name), buildFuncType,
202       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
203 }
204 
205 void FuncOp::print(OpAsmPrinter &p) {
206   function_interface_impl::printFunctionOp(
207       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
208       getArgAttrsAttrName(), getResAttrsAttrName());
209 }
210 
211 /// Clone the internal blocks from this function into dest and all attributes
212 /// from this function to dest.
213 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
214   // Add the attributes of this function to dest.
215   llvm::MapVector<StringAttr, Attribute> newAttrMap;
216   for (const auto &attr : dest->getAttrs())
217     newAttrMap.insert({attr.getName(), attr.getValue()});
218   for (const auto &attr : (*this)->getAttrs())
219     newAttrMap.insert({attr.getName(), attr.getValue()});
220 
221   auto newAttrs = llvm::to_vector(llvm::map_range(
222       newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
223         return NamedAttribute(attrPair.first, attrPair.second);
224       }));
225   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
226 
227   // Clone the body.
228   getBody().cloneInto(&dest.getBody(), mapper);
229 }
230 
231 /// Create a deep copy of this function and all of its blocks, remapping
232 /// any operands that use values outside of the function using the map that is
233 /// provided (leaving them alone if no entry is present). Replaces references
234 /// to cloned sub-values with the corresponding value that is copied, and adds
235 /// those mappings to the mapper.
236 FuncOp FuncOp::clone(IRMapping &mapper) {
237   // Create the new function.
238   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
239 
240   // If the function has a body, then the user might be deleting arguments to
241   // the function by specifying them in the mapper. If so, we don't add the
242   // argument to the input type vector.
243   if (!isExternal()) {
244     FunctionType oldType = getFunctionType();
245 
246     unsigned oldNumArgs = oldType.getNumInputs();
247     SmallVector<Type, 4> newInputs;
248     newInputs.reserve(oldNumArgs);
249     for (unsigned i = 0; i != oldNumArgs; ++i)
250       if (!mapper.contains(getArgument(i)))
251         newInputs.push_back(oldType.getInput(i));
252 
253     /// If any of the arguments were dropped, update the type and drop any
254     /// necessary argument attributes.
255     if (newInputs.size() != oldNumArgs) {
256       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
257                                         oldType.getResults()));
258 
259       if (ArrayAttr argAttrs = getAllArgAttrs()) {
260         SmallVector<Attribute> newArgAttrs;
261         newArgAttrs.reserve(newInputs.size());
262         for (unsigned i = 0; i != oldNumArgs; ++i)
263           if (!mapper.contains(getArgument(i)))
264             newArgAttrs.push_back(argAttrs[i]);
265         newFunc.setAllArgAttrs(newArgAttrs);
266       }
267     }
268   }
269 
270   /// Clone the current function into the new one and return it.
271   cloneInto(newFunc, mapper);
272   return newFunc;
273 }
274 FuncOp FuncOp::clone() {
275   IRMapping mapper;
276   return clone(mapper);
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // ReturnOp
281 //===----------------------------------------------------------------------===//
282 
283 LogicalResult ReturnOp::verify() {
284   auto function = cast<FuncOp>((*this)->getParentOp());
285 
286   // The operand number and types must match the function signature.
287   const auto &results = function.getFunctionType().getResults();
288   if (getNumOperands() != results.size())
289     return emitOpError("has ")
290            << getNumOperands() << " operands, but enclosing function (@"
291            << function.getName() << ") returns " << results.size();
292 
293   for (unsigned i = 0, e = results.size(); i != e; ++i)
294     if (getOperand(i).getType() != results[i])
295       return emitError() << "type of return operand " << i << " ("
296                          << getOperand(i).getType()
297                          << ") doesn't match function result type ("
298                          << results[i] << ")"
299                          << " in function @" << function.getName();
300 
301   return success();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // TableGen'd op method definitions
306 //===----------------------------------------------------------------------===//
307 
308 #define GET_OP_CLASSES
309 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
310