//===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>

// TableGen-generated definitions for the FuncDialect class itself.
#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"

using namespace mlir;
using namespace mlir::func;

//===----------------------------------------------------------------------===//
// FuncDialect
//===----------------------------------------------------------------------===//

/// Register all Func dialect operations and declare the interfaces this
/// dialect promises to provide.
void FuncDialect::initialize() {
  // With GET_OP_LIST defined, the TableGen'd include expands to the list of
  // op classes used as the template arguments of addOperations<>.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
      >();
  // Only declares that an implementation of DialectInlinerInterface is
  // promised for this dialect; the implementation itself is not in this file.
  declarePromisedInterface<FuncDialect, DialectInlinerInterface>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
48 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, 49 Type type, Location loc) { 50 if (ConstantOp::isBuildableWith(value, type)) 51 return builder.create<ConstantOp>(loc, type, 52 llvm::cast<FlatSymbolRefAttr>(value)); 53 return nullptr; 54 } 55 56 //===----------------------------------------------------------------------===// 57 // CallOp 58 //===----------------------------------------------------------------------===// 59 60 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 61 // Check that the callee attribute was specified. 62 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 63 if (!fnAttr) 64 return emitOpError("requires a 'callee' symbol reference attribute"); 65 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); 66 if (!fn) 67 return emitOpError() << "'" << fnAttr.getValue() 68 << "' does not reference a valid function"; 69 70 // Verify that the operand and result types match the callee. 
71 auto fnType = fn.getFunctionType(); 72 if (fnType.getNumInputs() != getNumOperands()) 73 return emitOpError("incorrect number of operands for callee"); 74 75 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) 76 if (getOperand(i).getType() != fnType.getInput(i)) 77 return emitOpError("operand type mismatch: expected operand type ") 78 << fnType.getInput(i) << ", but provided " 79 << getOperand(i).getType() << " for operand number " << i; 80 81 if (fnType.getNumResults() != getNumResults()) 82 return emitOpError("incorrect number of results for callee"); 83 84 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) 85 if (getResult(i).getType() != fnType.getResult(i)) { 86 auto diag = emitOpError("result type mismatch at index ") << i; 87 diag.attachNote() << " op result types: " << getResultTypes(); 88 diag.attachNote() << "function result types: " << fnType.getResults(); 89 return diag; 90 } 91 92 return success(); 93 } 94 95 FunctionType CallOp::getCalleeType() { 96 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // CallIndirectOp 101 //===----------------------------------------------------------------------===// 102 103 /// Fold indirect calls that have a constant function as the callee operand. 104 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, 105 PatternRewriter &rewriter) { 106 // Check that the callee is a constant callee. 107 SymbolRefAttr calledFn; 108 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) 109 return failure(); 110 111 // Replace with a direct call. 
112 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, 113 indirectCall.getResultTypes(), 114 indirectCall.getArgOperands()); 115 return success(); 116 } 117 118 //===----------------------------------------------------------------------===// 119 // ConstantOp 120 //===----------------------------------------------------------------------===// 121 122 LogicalResult ConstantOp::verify() { 123 StringRef fnName = getValue(); 124 Type type = getType(); 125 126 // Try to find the referenced function. 127 auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName); 128 if (!fn) 129 return emitOpError() << "reference to undefined function '" << fnName 130 << "'"; 131 132 // Check that the referenced function has the correct type. 133 if (fn.getFunctionType() != type) 134 return emitOpError("reference to function with mismatched type"); 135 136 return success(); 137 } 138 139 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { 140 return getValueAttr(); 141 } 142 143 void ConstantOp::getAsmResultNames( 144 function_ref<void(Value, StringRef)> setNameFn) { 145 setNameFn(getResult(), "f"); 146 } 147 148 bool ConstantOp::isBuildableWith(Attribute value, Type type) { 149 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // FuncOp 154 //===----------------------------------------------------------------------===// 155 156 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 157 ArrayRef<NamedAttribute> attrs) { 158 OpBuilder builder(location->getContext()); 159 OperationState state(location, getOperationName()); 160 FuncOp::build(builder, state, name, type, attrs); 161 return cast<FuncOp>(Operation::create(state)); 162 } 163 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 164 Operation::dialect_attr_range attrs) { 165 SmallVector<NamedAttribute, 8> attrRef(attrs); 166 return 
create(location, name, type, llvm::ArrayRef(attrRef)); 167 } 168 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 169 ArrayRef<NamedAttribute> attrs, 170 ArrayRef<DictionaryAttr> argAttrs) { 171 FuncOp func = create(location, name, type, attrs); 172 func.setAllArgAttrs(argAttrs); 173 return func; 174 } 175 176 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 177 FunctionType type, ArrayRef<NamedAttribute> attrs, 178 ArrayRef<DictionaryAttr> argAttrs) { 179 state.addAttribute(SymbolTable::getSymbolAttrName(), 180 builder.getStringAttr(name)); 181 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); 182 state.attributes.append(attrs.begin(), attrs.end()); 183 state.addRegion(); 184 185 if (argAttrs.empty()) 186 return; 187 assert(type.getNumInputs() == argAttrs.size()); 188 function_interface_impl::addArgAndResultAttrs( 189 builder, state, argAttrs, /*resultAttrs=*/std::nullopt, 190 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); 191 } 192 193 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 194 auto buildFuncType = 195 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 196 function_interface_impl::VariadicFlag, 197 std::string &) { return builder.getFunctionType(argTypes, results); }; 198 199 return function_interface_impl::parseFunctionOp( 200 parser, result, /*allowVariadic=*/false, 201 getFunctionTypeAttrName(result.name), buildFuncType, 202 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 203 } 204 205 void FuncOp::print(OpAsmPrinter &p) { 206 function_interface_impl::printFunctionOp( 207 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 208 getArgAttrsAttrName(), getResAttrsAttrName()); 209 } 210 211 /// Clone the internal blocks from this function into dest and all attributes 212 /// from this function to dest. 
213 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) { 214 // Add the attributes of this function to dest. 215 llvm::MapVector<StringAttr, Attribute> newAttrMap; 216 for (const auto &attr : dest->getAttrs()) 217 newAttrMap.insert({attr.getName(), attr.getValue()}); 218 for (const auto &attr : (*this)->getAttrs()) 219 newAttrMap.insert({attr.getName(), attr.getValue()}); 220 221 auto newAttrs = llvm::to_vector(llvm::map_range( 222 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { 223 return NamedAttribute(attrPair.first, attrPair.second); 224 })); 225 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); 226 227 // Clone the body. 228 getBody().cloneInto(&dest.getBody(), mapper); 229 } 230 231 /// Create a deep copy of this function and all of its blocks, remapping 232 /// any operands that use values outside of the function using the map that is 233 /// provided (leaving them alone if no entry is present). Replaces references 234 /// to cloned sub-values with the corresponding value that is copied, and adds 235 /// those mappings to the mapper. 236 FuncOp FuncOp::clone(IRMapping &mapper) { 237 // Create the new function. 238 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); 239 240 // If the function has a body, then the user might be deleting arguments to 241 // the function by specifying them in the mapper. If so, we don't add the 242 // argument to the input type vector. 243 if (!isExternal()) { 244 FunctionType oldType = getFunctionType(); 245 246 unsigned oldNumArgs = oldType.getNumInputs(); 247 SmallVector<Type, 4> newInputs; 248 newInputs.reserve(oldNumArgs); 249 for (unsigned i = 0; i != oldNumArgs; ++i) 250 if (!mapper.contains(getArgument(i))) 251 newInputs.push_back(oldType.getInput(i)); 252 253 /// If any of the arguments were dropped, update the type and drop any 254 /// necessary argument attributes. 
255 if (newInputs.size() != oldNumArgs) { 256 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, 257 oldType.getResults())); 258 259 if (ArrayAttr argAttrs = getAllArgAttrs()) { 260 SmallVector<Attribute> newArgAttrs; 261 newArgAttrs.reserve(newInputs.size()); 262 for (unsigned i = 0; i != oldNumArgs; ++i) 263 if (!mapper.contains(getArgument(i))) 264 newArgAttrs.push_back(argAttrs[i]); 265 newFunc.setAllArgAttrs(newArgAttrs); 266 } 267 } 268 } 269 270 /// Clone the current function into the new one and return it. 271 cloneInto(newFunc, mapper); 272 return newFunc; 273 } 274 FuncOp FuncOp::clone() { 275 IRMapping mapper; 276 return clone(mapper); 277 } 278 279 //===----------------------------------------------------------------------===// 280 // ReturnOp 281 //===----------------------------------------------------------------------===// 282 283 LogicalResult ReturnOp::verify() { 284 auto function = cast<FuncOp>((*this)->getParentOp()); 285 286 // The operand number and types must match the function signature. 287 const auto &results = function.getFunctionType().getResults(); 288 if (getNumOperands() != results.size()) 289 return emitOpError("has ") 290 << getNumOperands() << " operands, but enclosing function (@" 291 << function.getName() << ") returns " << results.size(); 292 293 for (unsigned i = 0, e = results.size(); i != e; ++i) 294 if (getOperand(i).getType() != results[i]) 295 return emitError() << "type of return operand " << i << " (" 296 << getOperand(i).getType() 297 << ") doesn't match function result type (" 298 << results[i] << ")" 299 << " in function @" << function.getName(); 300 301 return success(); 302 } 303 304 //===----------------------------------------------------------------------===// 305 // TableGen'd op method definitions 306 //===----------------------------------------------------------------------===// 307 308 #define GET_OP_CLASSES 309 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 310