1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/IR/FuncOps.h" 10 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "mlir/IR/IRMapping.h" 16 #include "mlir/IR/Matchers.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/IR/Value.h" 21 #include "mlir/Interfaces/FunctionImplementation.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/APFloat.h" 24 #include "llvm/ADT/MapVector.h" 25 #include "llvm/ADT/STLExtras.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include <numeric> 29 30 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" 31 32 using namespace mlir; 33 using namespace mlir::func; 34 35 //===----------------------------------------------------------------------===// 36 // FuncDialect 37 //===----------------------------------------------------------------------===// 38 39 void FuncDialect::initialize() { 40 addOperations< 41 #define GET_OP_LIST 42 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 43 >(); 44 declarePromisedInterface<DialectInlinerInterface, FuncDialect>(); 45 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>(); 46 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp, 47 FuncOp, ReturnOp>(); 48 } 49 50 /// Materialize a single constant operation from a given attribute value with 51 /// the desired resultant type. 52 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, 53 Type type, Location loc) { 54 if (ConstantOp::isBuildableWith(value, type)) 55 return builder.create<ConstantOp>(loc, type, 56 llvm::cast<FlatSymbolRefAttr>(value)); 57 return nullptr; 58 } 59 60 //===----------------------------------------------------------------------===// 61 // CallOp 62 //===----------------------------------------------------------------------===// 63 64 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 65 // Check that the callee attribute was specified. 66 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 67 if (!fnAttr) 68 return emitOpError("requires a 'callee' symbol reference attribute"); 69 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); 70 if (!fn) 71 return emitOpError() << "'" << fnAttr.getValue() 72 << "' does not reference a valid function"; 73 74 // Verify that the operand and result types match the callee. 75 auto fnType = fn.getFunctionType(); 76 if (fnType.getNumInputs() != getNumOperands()) 77 return emitOpError("incorrect number of operands for callee"); 78 79 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) 80 if (getOperand(i).getType() != fnType.getInput(i)) 81 return emitOpError("operand type mismatch: expected operand type ") 82 << fnType.getInput(i) << ", but provided " 83 << getOperand(i).getType() << " for operand number " << i; 84 85 if (fnType.getNumResults() != getNumResults()) 86 return emitOpError("incorrect number of results for callee"); 87 88 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) 89 if (getResult(i).getType() != fnType.getResult(i)) { 90 auto diag = emitOpError("result type mismatch at index ") << i; 91 diag.attachNote() << " op result types: " << getResultTypes(); 92 diag.attachNote() << "function result types: " << fnType.getResults(); 93 return diag; 94 } 95 96 return success(); 97 } 98 99 FunctionType CallOp::getCalleeType() { 100 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); 101 } 102 103 //===----------------------------------------------------------------------===// 104 // CallIndirectOp 105 //===----------------------------------------------------------------------===// 106 107 /// Fold indirect calls that have a constant function as the callee operand. 108 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, 109 PatternRewriter &rewriter) { 110 // Check that the callee is a constant callee. 111 SymbolRefAttr calledFn; 112 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) 113 return failure(); 114 115 // Replace with a direct call. 116 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, 117 indirectCall.getResultTypes(), 118 indirectCall.getArgOperands()); 119 return success(); 120 } 121 122 //===----------------------------------------------------------------------===// 123 // ConstantOp 124 //===----------------------------------------------------------------------===// 125 126 LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 127 StringRef fnName = getValue(); 128 Type type = getType(); 129 130 // Try to find the referenced function. 131 auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>( 132 this->getOperation(), StringAttr::get(getContext(), fnName)); 133 if (!fn) 134 return emitOpError() << "reference to undefined function '" << fnName 135 << "'"; 136 137 // Check that the referenced function has the correct type. 138 if (fn.getFunctionType() != type) 139 return emitOpError("reference to function with mismatched type"); 140 141 return success(); 142 } 143 144 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { 145 return getValueAttr(); 146 } 147 148 void ConstantOp::getAsmResultNames( 149 function_ref<void(Value, StringRef)> setNameFn) { 150 setNameFn(getResult(), "f"); 151 } 152 153 bool ConstantOp::isBuildableWith(Attribute value, Type type) { 154 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // FuncOp 159 //===----------------------------------------------------------------------===// 160 161 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 162 ArrayRef<NamedAttribute> attrs) { 163 OpBuilder builder(location->getContext()); 164 OperationState state(location, getOperationName()); 165 FuncOp::build(builder, state, name, type, attrs); 166 return cast<FuncOp>(Operation::create(state)); 167 } 168 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 169 Operation::dialect_attr_range attrs) { 170 SmallVector<NamedAttribute, 8> attrRef(attrs); 171 return create(location, name, type, llvm::ArrayRef(attrRef)); 172 } 173 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 174 ArrayRef<NamedAttribute> attrs, 175 ArrayRef<DictionaryAttr> argAttrs) { 176 FuncOp func = create(location, name, type, attrs); 177 func.setAllArgAttrs(argAttrs); 178 return func; 179 } 180 181 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 182 FunctionType type, ArrayRef<NamedAttribute> attrs, 183 ArrayRef<DictionaryAttr> argAttrs) { 184 state.addAttribute(SymbolTable::getSymbolAttrName(), 185 builder.getStringAttr(name)); 186 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); 187 state.attributes.append(attrs.begin(), attrs.end()); 188 state.addRegion(); 189 190 if (argAttrs.empty()) 191 return; 192 assert(type.getNumInputs() == argAttrs.size()); 193 function_interface_impl::addArgAndResultAttrs( 194 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, 195 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); 196 } 197 198 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 199 auto buildFuncType = 200 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 201 function_interface_impl::VariadicFlag, 202 std::string &) { return builder.getFunctionType(argTypes, results); }; 203 204 return function_interface_impl::parseFunctionOp( 205 parser, result, /*allowVariadic=*/false, 206 getFunctionTypeAttrName(result.name), buildFuncType, 207 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 208 } 209 210 void FuncOp::print(OpAsmPrinter &p) { 211 function_interface_impl::printFunctionOp( 212 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 213 getArgAttrsAttrName(), getResAttrsAttrName()); 214 } 215 216 /// Clone the internal blocks from this function into dest and all attributes 217 /// from this function to dest. 218 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) { 219 // Add the attributes of this function to dest. 220 llvm::MapVector<StringAttr, Attribute> newAttrMap; 221 for (const auto &attr : dest->getAttrs()) 222 newAttrMap.insert({attr.getName(), attr.getValue()}); 223 for (const auto &attr : (*this)->getAttrs()) 224 newAttrMap.insert({attr.getName(), attr.getValue()}); 225 226 auto newAttrs = llvm::to_vector(llvm::map_range( 227 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { 228 return NamedAttribute(attrPair.first, attrPair.second); 229 })); 230 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); 231 232 // Clone the body. 233 getBody().cloneInto(&dest.getBody(), mapper); 234 } 235 236 /// Create a deep copy of this function and all of its blocks, remapping 237 /// any operands that use values outside of the function using the map that is 238 /// provided (leaving them alone if no entry is present). Replaces references 239 /// to cloned sub-values with the corresponding value that is copied, and adds 240 /// those mappings to the mapper. 241 FuncOp FuncOp::clone(IRMapping &mapper) { 242 // Create the new function. 243 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); 244 245 // If the function has a body, then the user might be deleting arguments to 246 // the function by specifying them in the mapper. If so, we don't add the 247 // argument to the input type vector. 248 if (!isExternal()) { 249 FunctionType oldType = getFunctionType(); 250 251 unsigned oldNumArgs = oldType.getNumInputs(); 252 SmallVector<Type, 4> newInputs; 253 newInputs.reserve(oldNumArgs); 254 for (unsigned i = 0; i != oldNumArgs; ++i) 255 if (!mapper.contains(getArgument(i))) 256 newInputs.push_back(oldType.getInput(i)); 257 258 /// If any of the arguments were dropped, update the type and drop any 259 /// necessary argument attributes. 260 if (newInputs.size() != oldNumArgs) { 261 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, 262 oldType.getResults())); 263 264 if (ArrayAttr argAttrs = getAllArgAttrs()) { 265 SmallVector<Attribute> newArgAttrs; 266 newArgAttrs.reserve(newInputs.size()); 267 for (unsigned i = 0; i != oldNumArgs; ++i) 268 if (!mapper.contains(getArgument(i))) 269 newArgAttrs.push_back(argAttrs[i]); 270 newFunc.setAllArgAttrs(newArgAttrs); 271 } 272 } 273 } 274 275 /// Clone the current function into the new one and return it. 276 cloneInto(newFunc, mapper); 277 return newFunc; 278 } 279 FuncOp FuncOp::clone() { 280 IRMapping mapper; 281 return clone(mapper); 282 } 283 284 //===----------------------------------------------------------------------===// 285 // ReturnOp 286 //===----------------------------------------------------------------------===// 287 288 LogicalResult ReturnOp::verify() { 289 auto function = cast<FuncOp>((*this)->getParentOp()); 290 291 // The operand number and types must match the function signature. 292 const auto &results = function.getFunctionType().getResults(); 293 if (getNumOperands() != results.size()) 294 return emitOpError("has ") 295 << getNumOperands() << " operands, but enclosing function (@" 296 << function.getName() << ") returns " << results.size(); 297 298 for (unsigned i = 0, e = results.size(); i != e; ++i) 299 if (getOperand(i).getType() != results[i]) 300 return emitError() << "type of return operand " << i << " (" 301 << getOperand(i).getType() 302 << ") doesn't match function result type (" 303 << results[i] << ")" 304 << " in function @" << function.getName(); 305 306 return success(); 307 } 308 309 //===----------------------------------------------------------------------===// 310 // TableGen'd op method definitions 311 //===----------------------------------------------------------------------===// 312 313 #define GET_OP_CLASSES 314 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 315