xref: /llvm-project/mlir/lib/Dialect/Func/IR/FuncOps.cpp (revision 663e9cec9c96169aa4e72ab9b6bf08b2d6603093)
1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/IR/Value.h"
21 #include "mlir/Interfaces/FunctionImplementation.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <numeric>
29 
30 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
31 
32 using namespace mlir;
33 using namespace mlir::func;
34 
35 //===----------------------------------------------------------------------===//
36 // FuncDialect
37 //===----------------------------------------------------------------------===//
38 
39 void FuncDialect::initialize() {
40   addOperations<
41 #define GET_OP_LIST
42 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
43       >();
44   declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
45   declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
46   declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
47                             FuncOp, ReturnOp>();
48 }
49 
50 /// Materialize a single constant operation from a given attribute value with
51 /// the desired resultant type.
52 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
53                                             Type type, Location loc) {
54   if (ConstantOp::isBuildableWith(value, type))
55     return builder.create<ConstantOp>(loc, type,
56                                       llvm::cast<FlatSymbolRefAttr>(value));
57   return nullptr;
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // CallOp
62 //===----------------------------------------------------------------------===//
63 
64 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
65   // Check that the callee attribute was specified.
66   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
67   if (!fnAttr)
68     return emitOpError("requires a 'callee' symbol reference attribute");
69   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
70   if (!fn)
71     return emitOpError() << "'" << fnAttr.getValue()
72                          << "' does not reference a valid function";
73 
74   // Verify that the operand and result types match the callee.
75   auto fnType = fn.getFunctionType();
76   if (fnType.getNumInputs() != getNumOperands())
77     return emitOpError("incorrect number of operands for callee");
78 
79   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
80     if (getOperand(i).getType() != fnType.getInput(i))
81       return emitOpError("operand type mismatch: expected operand type ")
82              << fnType.getInput(i) << ", but provided "
83              << getOperand(i).getType() << " for operand number " << i;
84 
85   if (fnType.getNumResults() != getNumResults())
86     return emitOpError("incorrect number of results for callee");
87 
88   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
89     if (getResult(i).getType() != fnType.getResult(i)) {
90       auto diag = emitOpError("result type mismatch at index ") << i;
91       diag.attachNote() << "      op result types: " << getResultTypes();
92       diag.attachNote() << "function result types: " << fnType.getResults();
93       return diag;
94     }
95 
96   return success();
97 }
98 
99 FunctionType CallOp::getCalleeType() {
100   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // CallIndirectOp
105 //===----------------------------------------------------------------------===//
106 
107 /// Fold indirect calls that have a constant function as the callee operand.
108 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
109                                            PatternRewriter &rewriter) {
110   // Check that the callee is a constant callee.
111   SymbolRefAttr calledFn;
112   if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
113     return failure();
114 
115   // Replace with a direct call.
116   rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
117                                       indirectCall.getResultTypes(),
118                                       indirectCall.getArgOperands());
119   return success();
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // ConstantOp
124 //===----------------------------------------------------------------------===//
125 
126 LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
127   StringRef fnName = getValue();
128   Type type = getType();
129 
130   // Try to find the referenced function.
131   auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
132       this->getOperation(), StringAttr::get(getContext(), fnName));
133   if (!fn)
134     return emitOpError() << "reference to undefined function '" << fnName
135                          << "'";
136 
137   // Check that the referenced function has the correct type.
138   if (fn.getFunctionType() != type)
139     return emitOpError("reference to function with mismatched type");
140 
141   return success();
142 }
143 
144 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
145   return getValueAttr();
146 }
147 
148 void ConstantOp::getAsmResultNames(
149     function_ref<void(Value, StringRef)> setNameFn) {
150   setNameFn(getResult(), "f");
151 }
152 
153 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
154   return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // FuncOp
159 //===----------------------------------------------------------------------===//
160 
161 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
162                       ArrayRef<NamedAttribute> attrs) {
163   OpBuilder builder(location->getContext());
164   OperationState state(location, getOperationName());
165   FuncOp::build(builder, state, name, type, attrs);
166   return cast<FuncOp>(Operation::create(state));
167 }
168 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
169                       Operation::dialect_attr_range attrs) {
170   SmallVector<NamedAttribute, 8> attrRef(attrs);
171   return create(location, name, type, llvm::ArrayRef(attrRef));
172 }
173 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
174                       ArrayRef<NamedAttribute> attrs,
175                       ArrayRef<DictionaryAttr> argAttrs) {
176   FuncOp func = create(location, name, type, attrs);
177   func.setAllArgAttrs(argAttrs);
178   return func;
179 }
180 
181 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
182                    FunctionType type, ArrayRef<NamedAttribute> attrs,
183                    ArrayRef<DictionaryAttr> argAttrs) {
184   state.addAttribute(SymbolTable::getSymbolAttrName(),
185                      builder.getStringAttr(name));
186   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
187   state.attributes.append(attrs.begin(), attrs.end());
188   state.addRegion();
189 
190   if (argAttrs.empty())
191     return;
192   assert(type.getNumInputs() == argAttrs.size());
193   function_interface_impl::addArgAndResultAttrs(
194       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
195       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
196 }
197 
198 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
199   auto buildFuncType =
200       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
201          function_interface_impl::VariadicFlag,
202          std::string &) { return builder.getFunctionType(argTypes, results); };
203 
204   return function_interface_impl::parseFunctionOp(
205       parser, result, /*allowVariadic=*/false,
206       getFunctionTypeAttrName(result.name), buildFuncType,
207       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
208 }
209 
210 void FuncOp::print(OpAsmPrinter &p) {
211   function_interface_impl::printFunctionOp(
212       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
213       getArgAttrsAttrName(), getResAttrsAttrName());
214 }
215 
216 /// Clone the internal blocks from this function into dest and all attributes
217 /// from this function to dest.
218 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
219   // Add the attributes of this function to dest.
220   llvm::MapVector<StringAttr, Attribute> newAttrMap;
221   for (const auto &attr : dest->getAttrs())
222     newAttrMap.insert({attr.getName(), attr.getValue()});
223   for (const auto &attr : (*this)->getAttrs())
224     newAttrMap.insert({attr.getName(), attr.getValue()});
225 
226   auto newAttrs = llvm::to_vector(llvm::map_range(
227       newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
228         return NamedAttribute(attrPair.first, attrPair.second);
229       }));
230   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
231 
232   // Clone the body.
233   getBody().cloneInto(&dest.getBody(), mapper);
234 }
235 
236 /// Create a deep copy of this function and all of its blocks, remapping
237 /// any operands that use values outside of the function using the map that is
238 /// provided (leaving them alone if no entry is present). Replaces references
239 /// to cloned sub-values with the corresponding value that is copied, and adds
240 /// those mappings to the mapper.
241 FuncOp FuncOp::clone(IRMapping &mapper) {
242   // Create the new function.
243   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
244 
245   // If the function has a body, then the user might be deleting arguments to
246   // the function by specifying them in the mapper. If so, we don't add the
247   // argument to the input type vector.
248   if (!isExternal()) {
249     FunctionType oldType = getFunctionType();
250 
251     unsigned oldNumArgs = oldType.getNumInputs();
252     SmallVector<Type, 4> newInputs;
253     newInputs.reserve(oldNumArgs);
254     for (unsigned i = 0; i != oldNumArgs; ++i)
255       if (!mapper.contains(getArgument(i)))
256         newInputs.push_back(oldType.getInput(i));
257 
258     /// If any of the arguments were dropped, update the type and drop any
259     /// necessary argument attributes.
260     if (newInputs.size() != oldNumArgs) {
261       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
262                                         oldType.getResults()));
263 
264       if (ArrayAttr argAttrs = getAllArgAttrs()) {
265         SmallVector<Attribute> newArgAttrs;
266         newArgAttrs.reserve(newInputs.size());
267         for (unsigned i = 0; i != oldNumArgs; ++i)
268           if (!mapper.contains(getArgument(i)))
269             newArgAttrs.push_back(argAttrs[i]);
270         newFunc.setAllArgAttrs(newArgAttrs);
271       }
272     }
273   }
274 
275   /// Clone the current function into the new one and return it.
276   cloneInto(newFunc, mapper);
277   return newFunc;
278 }
279 FuncOp FuncOp::clone() {
280   IRMapping mapper;
281   return clone(mapper);
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // ReturnOp
286 //===----------------------------------------------------------------------===//
287 
288 LogicalResult ReturnOp::verify() {
289   auto function = cast<FuncOp>((*this)->getParentOp());
290 
291   // The operand number and types must match the function signature.
292   const auto &results = function.getFunctionType().getResults();
293   if (getNumOperands() != results.size())
294     return emitOpError("has ")
295            << getNumOperands() << " operands, but enclosing function (@"
296            << function.getName() << ") returns " << results.size();
297 
298   for (unsigned i = 0, e = results.size(); i != e; ++i)
299     if (getOperand(i).getType() != results[i])
300       return emitError() << "type of return operand " << i << " ("
301                          << getOperand(i).getType()
302                          << ") doesn't match function result type ("
303                          << results[i] << ")"
304                          << " in function @" << function.getName();
305 
306   return success();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // TableGen'd op method definitions
311 //===----------------------------------------------------------------------===//
312 
313 #define GET_OP_CLASSES
314 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
315