//===- FunctionSupport.cpp - Utility types for function-like ops ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Interfaces/FunctionInterfaces.h" using namespace mlir; //===----------------------------------------------------------------------===// // Tablegen Interface Definitions //===----------------------------------------------------------------------===// #include "mlir/Interfaces/FunctionInterfaces.cpp.inc" //===----------------------------------------------------------------------===// // Function Arguments and Results. //===----------------------------------------------------------------------===// static bool isEmptyAttrDict(Attribute attr) { return llvm::cast(attr).empty(); } DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, unsigned index) { ArrayAttr attrs = op.getArgAttrsAttr(); DictionaryAttr argAttrs = attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); return argAttrs; } DictionaryAttr function_interface_impl::getResultAttrDict(FunctionOpInterface op, unsigned index) { ArrayAttr attrs = op.getResAttrsAttr(); DictionaryAttr resAttrs = attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); return resAttrs; } ArrayRef function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { auto argDict = getArgAttrDict(op, index); return argDict ? argDict.getValue() : std::nullopt; } ArrayRef function_interface_impl::getResultAttrs(FunctionOpInterface op, unsigned index) { auto resultDict = getResultAttrDict(op, index); return resultDict ? resultDict.getValue() : std::nullopt; } /// Get either the argument or result attributes array. template static ArrayAttr getArgResAttrs(FunctionOpInterface op) { if constexpr (isArg) return op.getArgAttrsAttr(); else return op.getResAttrsAttr(); } /// Set either the argument or result attributes array. template static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { if constexpr (isArg) op.setArgAttrsAttr(attrs); else op.setResAttrsAttr(attrs); } /// Erase either the argument or result attributes array. template static void removeArgResAttrs(FunctionOpInterface op) { if constexpr (isArg) op.removeArgAttrsAttr(); else op.removeResAttrsAttr(); } /// Set all of the argument or result attribute dictionaries for a function. template static void setAllArgResAttrDicts(FunctionOpInterface op, ArrayRef attrs) { if (llvm::all_of(attrs, isEmptyAttrDict)) removeArgResAttrs(op); else setArgResAttrs(op, ArrayAttr::get(op->getContext(), attrs)); } void function_interface_impl::setAllArgAttrDicts( FunctionOpInterface op, ArrayRef attrs) { setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); } void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs) { auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { return !attr ? DictionaryAttr::get(op->getContext()) : attr; }); setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); } void function_interface_impl::setAllResultAttrDicts( FunctionOpInterface op, ArrayRef attrs) { setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); } void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, ArrayRef attrs) { auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { return !attr ? DictionaryAttr::get(op->getContext()) : attr; }); setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); } /// Update the given index into an argument or result attribute dictionary. template static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, unsigned index, DictionaryAttr attrs) { ArrayAttr allAttrs = getArgResAttrs(op); if (!allAttrs) { if (attrs.empty()) return; // If this attribute is not empty, we need to create a new attribute array. SmallVector newAttrs(numTotalIndices, DictionaryAttr::get(op->getContext())); newAttrs[index] = attrs; setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); return; } // Check to see if the attribute is different from what we already have. if (allAttrs[index] == attrs) return; // If it is, check to see if the attribute array would now contain only empty // dictionaries. ArrayRef rawAttrArray = allAttrs.getValue(); if (attrs.empty() && llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) return removeArgResAttrs(op); // Otherwise, create a new attribute array with the updated dictionary. SmallVector newAttrs(rawAttrArray); newAttrs[index] = attrs; setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); } void function_interface_impl::setArgAttrs(FunctionOpInterface op, unsigned index, ArrayRef attributes) { assert(index < op.getNumArguments() && "invalid argument number"); return setArgResAttrDict( op, op.getNumArguments(), index, DictionaryAttr::get(op->getContext(), attributes)); } void function_interface_impl::setArgAttrs(FunctionOpInterface op, unsigned index, DictionaryAttr attributes) { return setArgResAttrDict( op, op.getNumArguments(), index, attributes ? attributes : DictionaryAttr::get(op->getContext())); } void function_interface_impl::setResultAttrs( FunctionOpInterface op, unsigned index, ArrayRef attributes) { assert(index < op.getNumResults() && "invalid result number"); return setArgResAttrDict( op, op.getNumResults(), index, DictionaryAttr::get(op->getContext(), attributes)); } void function_interface_impl::setResultAttrs(FunctionOpInterface op, unsigned index, DictionaryAttr attributes) { assert(index < op.getNumResults() && "invalid result number"); return setArgResAttrDict( op, op.getNumResults(), index, attributes ? attributes : DictionaryAttr::get(op->getContext())); } void function_interface_impl::insertFunctionArguments( FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType) { assert(argIndices.size() == argTypes.size()); assert(argIndices.size() == argAttrs.size() || argAttrs.empty()); assert(argIndices.size() == argLocs.size()); if (argIndices.empty()) return; // There are 3 things that need to be updated: // - Function type. // - Arg attrs. // - Block arguments of entry block. Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); if (oldArgAttrs || !argAttrs.empty()) { SmallVector newArgAttrs; newArgAttrs.reserve(originalNumArgs + argIndices.size()); unsigned oldIdx = 0; auto migrate = [&](unsigned untilIdx) { if (!oldArgAttrs) { newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx); } else { auto oldArgAttrRange = oldArgAttrs.getAsRange(); newArgAttrs.append(oldArgAttrRange.begin() + oldIdx, oldArgAttrRange.begin() + untilIdx); } oldIdx = untilIdx; }; for (unsigned i = 0, e = argIndices.size(); i < e; ++i) { migrate(argIndices[i]); newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]); } migrate(originalNumArgs); setAllArgAttrDicts(op, newArgAttrs); } // Update the function type and any entry block arguments. op.setFunctionTypeAttr(TypeAttr::get(newType)); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); } void function_interface_impl::insertFunctionResults( FunctionOpInterface op, ArrayRef resultIndices, TypeRange resultTypes, ArrayRef resultAttrs, unsigned originalNumResults, Type newType) { assert(resultIndices.size() == resultTypes.size()); assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); if (resultIndices.empty()) return; // There are 2 things that need to be updated: // - Function type. // - Result attrs. // Update the result attributes of the function. ArrayAttr oldResultAttrs = op.getResAttrsAttr(); if (oldResultAttrs || !resultAttrs.empty()) { SmallVector newResultAttrs; newResultAttrs.reserve(originalNumResults + resultIndices.size()); unsigned oldIdx = 0; auto migrate = [&](unsigned untilIdx) { if (!oldResultAttrs) { newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx); } else { auto oldResultAttrsRange = oldResultAttrs.getAsRange(); newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx, oldResultAttrsRange.begin() + untilIdx); } oldIdx = untilIdx; }; for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) { migrate(resultIndices[i]); newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{} : resultAttrs[i]); } migrate(originalNumResults); setAllResultAttrDicts(op, newResultAttrs); } // Update the function type. op.setFunctionTypeAttr(TypeAttr::get(newType)); } void function_interface_impl::eraseFunctionArguments( FunctionOpInterface op, const BitVector &argIndices, Type newType) { // There are 3 things that need to be updated: // - Function type. // - Arg attrs. // - Block arguments of entry block. Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { SmallVector newArgAttrs; newArgAttrs.reserve(argAttrs.size()); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) if (!argIndices[i]) newArgAttrs.emplace_back(llvm::cast(argAttrs[i])); setAllArgAttrDicts(op, newArgAttrs); } // Update the function type and any entry block arguments. op.setFunctionTypeAttr(TypeAttr::get(newType)); entry.eraseArguments(argIndices); } void function_interface_impl::eraseFunctionResults( FunctionOpInterface op, const BitVector &resultIndices, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. // Update the result attributes of the function. if (ArrayAttr resAttrs = op.getResAttrsAttr()) { SmallVector newResultAttrs; newResultAttrs.reserve(resAttrs.size()); for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) if (!resultIndices[i]) newResultAttrs.emplace_back(llvm::cast(resAttrs[i])); setAllResultAttrDicts(op, newResultAttrs); } // Update the function type. op.setFunctionTypeAttr(TypeAttr::get(newType)); } //===----------------------------------------------------------------------===// // Function type signature. //===----------------------------------------------------------------------===// void function_interface_impl::setFunctionType(FunctionOpInterface op, Type newType) { unsigned oldNumArgs = op.getNumArguments(); unsigned oldNumResults = op.getNumResults(); op.setFunctionTypeAttr(TypeAttr::get(newType)); unsigned newNumArgs = op.getNumArguments(); unsigned newNumResults = op.getNumResults(); // Functor used to update the argument and result attributes of the function. auto emptyDict = DictionaryAttr::get(op.getContext()); auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { constexpr bool isArgVal = std::is_same_v; if (oldCount == newCount) return; // The new type has no arguments/results, just drop the attribute. if (newCount == 0) return removeArgResAttrs(op); ArrayAttr attrs = getArgResAttrs(op); if (!attrs) return; // The new type has less arguments/results, take the first N attributes. if (newCount < oldCount) return setAllArgResAttrDicts( op, attrs.getValue().take_front(newCount)); // Otherwise, the new type has more arguments/results. Initialize the new // arguments/results with empty dictionary attributes. SmallVector newAttrs(attrs.begin(), attrs.end()); newAttrs.resize(newCount, emptyDict); setAllArgResAttrDicts(op, newAttrs); }; // Update the argument and result attributes. updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); updateAttrFn(std::false_type{}, oldNumResults, newNumResults); }