//===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>

#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"

using namespace mlir;
using namespace mlir::func;

3523aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 3623aa5a74SRiver Riddle // FuncDialect 3723aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 3823aa5a74SRiver Riddle 3923aa5a74SRiver Riddle void FuncDialect::initialize() { 4023aa5a74SRiver Riddle addOperations< 4123aa5a74SRiver Riddle #define GET_OP_LIST 4223aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 4323aa5a74SRiver Riddle >(); 4435d55f28SJustin Fargnoli declarePromisedInterface<DialectInlinerInterface, FuncDialect>(); 4535d55f28SJustin Fargnoli declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>(); 46513cdb82SJustin Fargnoli declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp, 47513cdb82SJustin Fargnoli FuncOp, ReturnOp>(); 4823aa5a74SRiver Riddle } 4923aa5a74SRiver Riddle 5023aa5a74SRiver Riddle /// Materialize a single constant operation from a given attribute value with 5123aa5a74SRiver Riddle /// the desired resultant type. 5223aa5a74SRiver Riddle Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, 5323aa5a74SRiver Riddle Type type, Location loc) { 5423aa5a74SRiver Riddle if (ConstantOp::isBuildableWith(value, type)) 5523aa5a74SRiver Riddle return builder.create<ConstantOp>(loc, type, 56c1fa60b4STres Popp llvm::cast<FlatSymbolRefAttr>(value)); 5723aa5a74SRiver Riddle return nullptr; 5823aa5a74SRiver Riddle } 5923aa5a74SRiver Riddle 6023aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 6123aa5a74SRiver Riddle // CallOp 6223aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 6323aa5a74SRiver Riddle 6423aa5a74SRiver Riddle LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 6523aa5a74SRiver Riddle // Check that the callee attribute was specified. 
6623aa5a74SRiver Riddle auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 6723aa5a74SRiver Riddle if (!fnAttr) 6823aa5a74SRiver Riddle return emitOpError("requires a 'callee' symbol reference attribute"); 6923aa5a74SRiver Riddle FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); 7023aa5a74SRiver Riddle if (!fn) 7123aa5a74SRiver Riddle return emitOpError() << "'" << fnAttr.getValue() 7223aa5a74SRiver Riddle << "' does not reference a valid function"; 7323aa5a74SRiver Riddle 7423aa5a74SRiver Riddle // Verify that the operand and result types match the callee. 754a3460a7SRiver Riddle auto fnType = fn.getFunctionType(); 7623aa5a74SRiver Riddle if (fnType.getNumInputs() != getNumOperands()) 7723aa5a74SRiver Riddle return emitOpError("incorrect number of operands for callee"); 7823aa5a74SRiver Riddle 7923aa5a74SRiver Riddle for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) 8023aa5a74SRiver Riddle if (getOperand(i).getType() != fnType.getInput(i)) 8123aa5a74SRiver Riddle return emitOpError("operand type mismatch: expected operand type ") 8223aa5a74SRiver Riddle << fnType.getInput(i) << ", but provided " 8323aa5a74SRiver Riddle << getOperand(i).getType() << " for operand number " << i; 8423aa5a74SRiver Riddle 8523aa5a74SRiver Riddle if (fnType.getNumResults() != getNumResults()) 8623aa5a74SRiver Riddle return emitOpError("incorrect number of results for callee"); 8723aa5a74SRiver Riddle 8823aa5a74SRiver Riddle for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) 8923aa5a74SRiver Riddle if (getResult(i).getType() != fnType.getResult(i)) { 9023aa5a74SRiver Riddle auto diag = emitOpError("result type mismatch at index ") << i; 9123aa5a74SRiver Riddle diag.attachNote() << " op result types: " << getResultTypes(); 9223aa5a74SRiver Riddle diag.attachNote() << "function result types: " << fnType.getResults(); 9323aa5a74SRiver Riddle return diag; 9423aa5a74SRiver Riddle } 9523aa5a74SRiver Riddle 9623aa5a74SRiver Riddle 
return success(); 9723aa5a74SRiver Riddle } 9823aa5a74SRiver Riddle 9923aa5a74SRiver Riddle FunctionType CallOp::getCalleeType() { 10023aa5a74SRiver Riddle return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); 10123aa5a74SRiver Riddle } 10223aa5a74SRiver Riddle 10323aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 10423aa5a74SRiver Riddle // CallIndirectOp 10523aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 10623aa5a74SRiver Riddle 10723aa5a74SRiver Riddle /// Fold indirect calls that have a constant function as the callee operand. 10823aa5a74SRiver Riddle LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, 10923aa5a74SRiver Riddle PatternRewriter &rewriter) { 11023aa5a74SRiver Riddle // Check that the callee is a constant callee. 11123aa5a74SRiver Riddle SymbolRefAttr calledFn; 11223aa5a74SRiver Riddle if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) 11323aa5a74SRiver Riddle return failure(); 11423aa5a74SRiver Riddle 11523aa5a74SRiver Riddle // Replace with a direct call. 
11623aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, 11723aa5a74SRiver Riddle indirectCall.getResultTypes(), 11823aa5a74SRiver Riddle indirectCall.getArgOperands()); 11923aa5a74SRiver Riddle return success(); 12023aa5a74SRiver Riddle } 12123aa5a74SRiver Riddle 12223aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 12323aa5a74SRiver Riddle // ConstantOp 12423aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 12523aa5a74SRiver Riddle 126*663e9cecSArtem Kroviakov LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 12723aa5a74SRiver Riddle StringRef fnName = getValue(); 12823aa5a74SRiver Riddle Type type = getType(); 12923aa5a74SRiver Riddle 13023aa5a74SRiver Riddle // Try to find the referenced function. 131*663e9cecSArtem Kroviakov auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>( 132*663e9cecSArtem Kroviakov this->getOperation(), StringAttr::get(getContext(), fnName)); 13323aa5a74SRiver Riddle if (!fn) 13423aa5a74SRiver Riddle return emitOpError() << "reference to undefined function '" << fnName 13523aa5a74SRiver Riddle << "'"; 13623aa5a74SRiver Riddle 13723aa5a74SRiver Riddle // Check that the referenced function has the correct type. 
1384a3460a7SRiver Riddle if (fn.getFunctionType() != type) 13923aa5a74SRiver Riddle return emitOpError("reference to function with mismatched type"); 14023aa5a74SRiver Riddle 14123aa5a74SRiver Riddle return success(); 14223aa5a74SRiver Riddle } 14323aa5a74SRiver Riddle 1447df76121SMarkus Böck OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { 14523aa5a74SRiver Riddle return getValueAttr(); 14623aa5a74SRiver Riddle } 14723aa5a74SRiver Riddle 14823aa5a74SRiver Riddle void ConstantOp::getAsmResultNames( 14923aa5a74SRiver Riddle function_ref<void(Value, StringRef)> setNameFn) { 15023aa5a74SRiver Riddle setNameFn(getResult(), "f"); 15123aa5a74SRiver Riddle } 15223aa5a74SRiver Riddle 15323aa5a74SRiver Riddle bool ConstantOp::isBuildableWith(Attribute value, Type type) { 154c1fa60b4STres Popp return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type); 15523aa5a74SRiver Riddle } 15623aa5a74SRiver Riddle 15723aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 15836550692SRiver Riddle // FuncOp 15936550692SRiver Riddle //===----------------------------------------------------------------------===// 16036550692SRiver Riddle 16136550692SRiver Riddle FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 16236550692SRiver Riddle ArrayRef<NamedAttribute> attrs) { 16336550692SRiver Riddle OpBuilder builder(location->getContext()); 16436550692SRiver Riddle OperationState state(location, getOperationName()); 16536550692SRiver Riddle FuncOp::build(builder, state, name, type, attrs); 16636550692SRiver Riddle return cast<FuncOp>(Operation::create(state)); 16736550692SRiver Riddle } 16836550692SRiver Riddle FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 16936550692SRiver Riddle Operation::dialect_attr_range attrs) { 17036550692SRiver Riddle SmallVector<NamedAttribute, 8> attrRef(attrs); 171984b800aSserge-sans-paille return create(location, name, type, 
llvm::ArrayRef(attrRef)); 17236550692SRiver Riddle } 17336550692SRiver Riddle FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 17436550692SRiver Riddle ArrayRef<NamedAttribute> attrs, 17536550692SRiver Riddle ArrayRef<DictionaryAttr> argAttrs) { 17636550692SRiver Riddle FuncOp func = create(location, name, type, attrs); 17736550692SRiver Riddle func.setAllArgAttrs(argAttrs); 17836550692SRiver Riddle return func; 17936550692SRiver Riddle } 18036550692SRiver Riddle 18136550692SRiver Riddle void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 18236550692SRiver Riddle FunctionType type, ArrayRef<NamedAttribute> attrs, 18336550692SRiver Riddle ArrayRef<DictionaryAttr> argAttrs) { 18436550692SRiver Riddle state.addAttribute(SymbolTable::getSymbolAttrName(), 18536550692SRiver Riddle builder.getStringAttr(name)); 18653406427SJeff Niu state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); 18736550692SRiver Riddle state.attributes.append(attrs.begin(), attrs.end()); 18836550692SRiver Riddle state.addRegion(); 18936550692SRiver Riddle 19036550692SRiver Riddle if (argAttrs.empty()) 19136550692SRiver Riddle return; 19236550692SRiver Riddle assert(type.getNumInputs() == argAttrs.size()); 19353406427SJeff Niu function_interface_impl::addArgAndResultAttrs( 19453406427SJeff Niu builder, state, argAttrs, /*resultAttrs=*/std::nullopt, 19553406427SJeff Niu getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); 19636550692SRiver Riddle } 19736550692SRiver Riddle 19836550692SRiver Riddle ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 19936550692SRiver Riddle auto buildFuncType = 20036550692SRiver Riddle [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 20136550692SRiver Riddle function_interface_impl::VariadicFlag, 20236550692SRiver Riddle std::string &) { return builder.getFunctionType(argTypes, results); }; 20336550692SRiver Riddle 20436550692SRiver 
Riddle return function_interface_impl::parseFunctionOp( 20553406427SJeff Niu parser, result, /*allowVariadic=*/false, 20653406427SJeff Niu getFunctionTypeAttrName(result.name), buildFuncType, 20753406427SJeff Niu getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 20836550692SRiver Riddle } 20936550692SRiver Riddle 21036550692SRiver Riddle void FuncOp::print(OpAsmPrinter &p) { 21153406427SJeff Niu function_interface_impl::printFunctionOp( 21253406427SJeff Niu p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 21353406427SJeff Niu getArgAttrsAttrName(), getResAttrsAttrName()); 21436550692SRiver Riddle } 21536550692SRiver Riddle 21636550692SRiver Riddle /// Clone the internal blocks from this function into dest and all attributes 21736550692SRiver Riddle /// from this function to dest. 2184d67b278SJeff Niu void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) { 21936550692SRiver Riddle // Add the attributes of this function to dest. 22036550692SRiver Riddle llvm::MapVector<StringAttr, Attribute> newAttrMap; 22136550692SRiver Riddle for (const auto &attr : dest->getAttrs()) 22236550692SRiver Riddle newAttrMap.insert({attr.getName(), attr.getValue()}); 22336550692SRiver Riddle for (const auto &attr : (*this)->getAttrs()) 22436550692SRiver Riddle newAttrMap.insert({attr.getName(), attr.getValue()}); 22536550692SRiver Riddle 22636550692SRiver Riddle auto newAttrs = llvm::to_vector(llvm::map_range( 22736550692SRiver Riddle newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { 22836550692SRiver Riddle return NamedAttribute(attrPair.first, attrPair.second); 22936550692SRiver Riddle })); 23036550692SRiver Riddle dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); 23136550692SRiver Riddle 23236550692SRiver Riddle // Clone the body. 
23336550692SRiver Riddle getBody().cloneInto(&dest.getBody(), mapper); 23436550692SRiver Riddle } 23536550692SRiver Riddle 23636550692SRiver Riddle /// Create a deep copy of this function and all of its blocks, remapping 23736550692SRiver Riddle /// any operands that use values outside of the function using the map that is 23836550692SRiver Riddle /// provided (leaving them alone if no entry is present). Replaces references 23936550692SRiver Riddle /// to cloned sub-values with the corresponding value that is copied, and adds 24036550692SRiver Riddle /// those mappings to the mapper. 2414d67b278SJeff Niu FuncOp FuncOp::clone(IRMapping &mapper) { 24236550692SRiver Riddle // Create the new function. 24336550692SRiver Riddle FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); 24436550692SRiver Riddle 24536550692SRiver Riddle // If the function has a body, then the user might be deleting arguments to 24636550692SRiver Riddle // the function by specifying them in the mapper. If so, we don't add the 24736550692SRiver Riddle // argument to the input type vector. 24836550692SRiver Riddle if (!isExternal()) { 2494a3460a7SRiver Riddle FunctionType oldType = getFunctionType(); 25036550692SRiver Riddle 25136550692SRiver Riddle unsigned oldNumArgs = oldType.getNumInputs(); 25236550692SRiver Riddle SmallVector<Type, 4> newInputs; 25336550692SRiver Riddle newInputs.reserve(oldNumArgs); 25436550692SRiver Riddle for (unsigned i = 0; i != oldNumArgs; ++i) 25536550692SRiver Riddle if (!mapper.contains(getArgument(i))) 25636550692SRiver Riddle newInputs.push_back(oldType.getInput(i)); 25736550692SRiver Riddle 25836550692SRiver Riddle /// If any of the arguments were dropped, update the type and drop any 25936550692SRiver Riddle /// necessary argument attributes. 
26036550692SRiver Riddle if (newInputs.size() != oldNumArgs) { 26136550692SRiver Riddle newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, 26236550692SRiver Riddle oldType.getResults())); 26336550692SRiver Riddle 26436550692SRiver Riddle if (ArrayAttr argAttrs = getAllArgAttrs()) { 26536550692SRiver Riddle SmallVector<Attribute> newArgAttrs; 26636550692SRiver Riddle newArgAttrs.reserve(newInputs.size()); 26736550692SRiver Riddle for (unsigned i = 0; i != oldNumArgs; ++i) 26836550692SRiver Riddle if (!mapper.contains(getArgument(i))) 26936550692SRiver Riddle newArgAttrs.push_back(argAttrs[i]); 27036550692SRiver Riddle newFunc.setAllArgAttrs(newArgAttrs); 27136550692SRiver Riddle } 27236550692SRiver Riddle } 27336550692SRiver Riddle } 27436550692SRiver Riddle 27536550692SRiver Riddle /// Clone the current function into the new one and return it. 27636550692SRiver Riddle cloneInto(newFunc, mapper); 27736550692SRiver Riddle return newFunc; 27836550692SRiver Riddle } 27936550692SRiver Riddle FuncOp FuncOp::clone() { 2804d67b278SJeff Niu IRMapping mapper; 28136550692SRiver Riddle return clone(mapper); 28236550692SRiver Riddle } 28336550692SRiver Riddle 28436550692SRiver Riddle //===----------------------------------------------------------------------===// 28523aa5a74SRiver Riddle // ReturnOp 28623aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 28723aa5a74SRiver Riddle 28823aa5a74SRiver Riddle LogicalResult ReturnOp::verify() { 28923aa5a74SRiver Riddle auto function = cast<FuncOp>((*this)->getParentOp()); 29023aa5a74SRiver Riddle 29123aa5a74SRiver Riddle // The operand number and types must match the function signature. 
2924a3460a7SRiver Riddle const auto &results = function.getFunctionType().getResults(); 29323aa5a74SRiver Riddle if (getNumOperands() != results.size()) 29423aa5a74SRiver Riddle return emitOpError("has ") 29523aa5a74SRiver Riddle << getNumOperands() << " operands, but enclosing function (@" 29623aa5a74SRiver Riddle << function.getName() << ") returns " << results.size(); 29723aa5a74SRiver Riddle 29823aa5a74SRiver Riddle for (unsigned i = 0, e = results.size(); i != e; ++i) 29923aa5a74SRiver Riddle if (getOperand(i).getType() != results[i]) 30023aa5a74SRiver Riddle return emitError() << "type of return operand " << i << " (" 30123aa5a74SRiver Riddle << getOperand(i).getType() 30223aa5a74SRiver Riddle << ") doesn't match function result type (" 30323aa5a74SRiver Riddle << results[i] << ")" 30423aa5a74SRiver Riddle << " in function @" << function.getName(); 30523aa5a74SRiver Riddle 30623aa5a74SRiver Riddle return success(); 30723aa5a74SRiver Riddle } 30823aa5a74SRiver Riddle 30923aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 31023aa5a74SRiver Riddle // TableGen'd op method definitions 31123aa5a74SRiver Riddle //===----------------------------------------------------------------------===// 31223aa5a74SRiver Riddle 31323aa5a74SRiver Riddle #define GET_OP_CLASSES 31423aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 315