xref: /llvm-project/mlir/lib/Interfaces/FunctionInterfaces.cpp (revision 5262865aac683b72f3e66de7a122e0c455ab6b9b)
//===- FunctionInterfaces.cpp - Utility types for function-like ops -------===//
234a35a8bSMartin Erhart //
334a35a8bSMartin Erhart // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
434a35a8bSMartin Erhart // See https://llvm.org/LICENSE.txt for license information.
534a35a8bSMartin Erhart // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
634a35a8bSMartin Erhart //
734a35a8bSMartin Erhart //===----------------------------------------------------------------------===//
834a35a8bSMartin Erhart 
934a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
1034a35a8bSMartin Erhart 
1134a35a8bSMartin Erhart using namespace mlir;
1234a35a8bSMartin Erhart 
1334a35a8bSMartin Erhart //===----------------------------------------------------------------------===//
1434a35a8bSMartin Erhart // Tablegen Interface Definitions
1534a35a8bSMartin Erhart //===----------------------------------------------------------------------===//
1634a35a8bSMartin Erhart 
1734a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.cpp.inc"
1834a35a8bSMartin Erhart 
1934a35a8bSMartin Erhart //===----------------------------------------------------------------------===//
2034a35a8bSMartin Erhart // Function Arguments and Results.
2134a35a8bSMartin Erhart //===----------------------------------------------------------------------===//
2234a35a8bSMartin Erhart 
2334a35a8bSMartin Erhart static bool isEmptyAttrDict(Attribute attr) {
2434a35a8bSMartin Erhart   return llvm::cast<DictionaryAttr>(attr).empty();
2534a35a8bSMartin Erhart }
2634a35a8bSMartin Erhart 
2734a35a8bSMartin Erhart DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
2834a35a8bSMartin Erhart                                                        unsigned index) {
2934a35a8bSMartin Erhart   ArrayAttr attrs = op.getArgAttrsAttr();
3034a35a8bSMartin Erhart   DictionaryAttr argAttrs =
3134a35a8bSMartin Erhart       attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
3234a35a8bSMartin Erhart   return argAttrs;
3334a35a8bSMartin Erhart }
3434a35a8bSMartin Erhart 
3534a35a8bSMartin Erhart DictionaryAttr
3634a35a8bSMartin Erhart function_interface_impl::getResultAttrDict(FunctionOpInterface op,
3734a35a8bSMartin Erhart                                            unsigned index) {
3834a35a8bSMartin Erhart   ArrayAttr attrs = op.getResAttrsAttr();
3934a35a8bSMartin Erhart   DictionaryAttr resAttrs =
4034a35a8bSMartin Erhart       attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
4134a35a8bSMartin Erhart   return resAttrs;
4234a35a8bSMartin Erhart }
4334a35a8bSMartin Erhart 
4434a35a8bSMartin Erhart ArrayRef<NamedAttribute>
4534a35a8bSMartin Erhart function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
4634a35a8bSMartin Erhart   auto argDict = getArgAttrDict(op, index);
4734a35a8bSMartin Erhart   return argDict ? argDict.getValue() : std::nullopt;
4834a35a8bSMartin Erhart }
4934a35a8bSMartin Erhart 
5034a35a8bSMartin Erhart ArrayRef<NamedAttribute>
5134a35a8bSMartin Erhart function_interface_impl::getResultAttrs(FunctionOpInterface op,
5234a35a8bSMartin Erhart                                         unsigned index) {
5334a35a8bSMartin Erhart   auto resultDict = getResultAttrDict(op, index);
5434a35a8bSMartin Erhart   return resultDict ? resultDict.getValue() : std::nullopt;
5534a35a8bSMartin Erhart }
5634a35a8bSMartin Erhart 
5734a35a8bSMartin Erhart /// Get either the argument or result attributes array.
5834a35a8bSMartin Erhart template <bool isArg>
5934a35a8bSMartin Erhart static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
6034a35a8bSMartin Erhart   if constexpr (isArg)
6134a35a8bSMartin Erhart     return op.getArgAttrsAttr();
6234a35a8bSMartin Erhart   else
6334a35a8bSMartin Erhart     return op.getResAttrsAttr();
6434a35a8bSMartin Erhart }
6534a35a8bSMartin Erhart 
6634a35a8bSMartin Erhart /// Set either the argument or result attributes array.
6734a35a8bSMartin Erhart template <bool isArg>
6834a35a8bSMartin Erhart static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
6934a35a8bSMartin Erhart   if constexpr (isArg)
7034a35a8bSMartin Erhart     op.setArgAttrsAttr(attrs);
7134a35a8bSMartin Erhart   else
7234a35a8bSMartin Erhart     op.setResAttrsAttr(attrs);
7334a35a8bSMartin Erhart }
7434a35a8bSMartin Erhart 
7534a35a8bSMartin Erhart /// Erase either the argument or result attributes array.
7634a35a8bSMartin Erhart template <bool isArg>
7734a35a8bSMartin Erhart static void removeArgResAttrs(FunctionOpInterface op) {
7834a35a8bSMartin Erhart   if constexpr (isArg)
7934a35a8bSMartin Erhart     op.removeArgAttrsAttr();
8034a35a8bSMartin Erhart   else
8134a35a8bSMartin Erhart     op.removeResAttrsAttr();
8234a35a8bSMartin Erhart }
8334a35a8bSMartin Erhart 
8434a35a8bSMartin Erhart /// Set all of the argument or result attribute dictionaries for a function.
8534a35a8bSMartin Erhart template <bool isArg>
8634a35a8bSMartin Erhart static void setAllArgResAttrDicts(FunctionOpInterface op,
8734a35a8bSMartin Erhart                                   ArrayRef<Attribute> attrs) {
8834a35a8bSMartin Erhart   if (llvm::all_of(attrs, isEmptyAttrDict))
8934a35a8bSMartin Erhart     removeArgResAttrs<isArg>(op);
9034a35a8bSMartin Erhart   else
9134a35a8bSMartin Erhart     setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
9234a35a8bSMartin Erhart }
9334a35a8bSMartin Erhart 
9434a35a8bSMartin Erhart void function_interface_impl::setAllArgAttrDicts(
9534a35a8bSMartin Erhart     FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
9634a35a8bSMartin Erhart   setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
9734a35a8bSMartin Erhart }
9834a35a8bSMartin Erhart 
9934a35a8bSMartin Erhart void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
10034a35a8bSMartin Erhart                                                  ArrayRef<Attribute> attrs) {
10134a35a8bSMartin Erhart   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
10234a35a8bSMartin Erhart     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
10334a35a8bSMartin Erhart   });
10434a35a8bSMartin Erhart   setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
10534a35a8bSMartin Erhart }
10634a35a8bSMartin Erhart 
10734a35a8bSMartin Erhart void function_interface_impl::setAllResultAttrDicts(
10834a35a8bSMartin Erhart     FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
10934a35a8bSMartin Erhart   setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
11034a35a8bSMartin Erhart }
11134a35a8bSMartin Erhart 
11234a35a8bSMartin Erhart void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
11334a35a8bSMartin Erhart                                                     ArrayRef<Attribute> attrs) {
11434a35a8bSMartin Erhart   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
11534a35a8bSMartin Erhart     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
11634a35a8bSMartin Erhart   });
11734a35a8bSMartin Erhart   setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
11834a35a8bSMartin Erhart }
11934a35a8bSMartin Erhart 
12034a35a8bSMartin Erhart /// Update the given index into an argument or result attribute dictionary.
12134a35a8bSMartin Erhart template <bool isArg>
12234a35a8bSMartin Erhart static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
12334a35a8bSMartin Erhart                               unsigned index, DictionaryAttr attrs) {
12434a35a8bSMartin Erhart   ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
12534a35a8bSMartin Erhart   if (!allAttrs) {
12634a35a8bSMartin Erhart     if (attrs.empty())
12734a35a8bSMartin Erhart       return;
12834a35a8bSMartin Erhart 
12934a35a8bSMartin Erhart     // If this attribute is not empty, we need to create a new attribute array.
13034a35a8bSMartin Erhart     SmallVector<Attribute, 8> newAttrs(numTotalIndices,
13134a35a8bSMartin Erhart                                        DictionaryAttr::get(op->getContext()));
13234a35a8bSMartin Erhart     newAttrs[index] = attrs;
13334a35a8bSMartin Erhart     setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
13434a35a8bSMartin Erhart     return;
13534a35a8bSMartin Erhart   }
13634a35a8bSMartin Erhart   // Check to see if the attribute is different from what we already have.
13734a35a8bSMartin Erhart   if (allAttrs[index] == attrs)
13834a35a8bSMartin Erhart     return;
13934a35a8bSMartin Erhart 
14034a35a8bSMartin Erhart   // If it is, check to see if the attribute array would now contain only empty
14134a35a8bSMartin Erhart   // dictionaries.
14234a35a8bSMartin Erhart   ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
14334a35a8bSMartin Erhart   if (attrs.empty() &&
14434a35a8bSMartin Erhart       llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
14534a35a8bSMartin Erhart       llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
14634a35a8bSMartin Erhart     return removeArgResAttrs<isArg>(op);
14734a35a8bSMartin Erhart 
14834a35a8bSMartin Erhart   // Otherwise, create a new attribute array with the updated dictionary.
149*5262865aSKazu Hirata   SmallVector<Attribute, 8> newAttrs(rawAttrArray);
15034a35a8bSMartin Erhart   newAttrs[index] = attrs;
15134a35a8bSMartin Erhart   setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
15234a35a8bSMartin Erhart }
15334a35a8bSMartin Erhart 
15434a35a8bSMartin Erhart void function_interface_impl::setArgAttrs(FunctionOpInterface op,
15534a35a8bSMartin Erhart                                           unsigned index,
15634a35a8bSMartin Erhart                                           ArrayRef<NamedAttribute> attributes) {
15734a35a8bSMartin Erhart   assert(index < op.getNumArguments() && "invalid argument number");
15834a35a8bSMartin Erhart   return setArgResAttrDict</*isArg=*/true>(
15934a35a8bSMartin Erhart       op, op.getNumArguments(), index,
16034a35a8bSMartin Erhart       DictionaryAttr::get(op->getContext(), attributes));
16134a35a8bSMartin Erhart }
16234a35a8bSMartin Erhart 
16334a35a8bSMartin Erhart void function_interface_impl::setArgAttrs(FunctionOpInterface op,
16434a35a8bSMartin Erhart                                           unsigned index,
16534a35a8bSMartin Erhart                                           DictionaryAttr attributes) {
16634a35a8bSMartin Erhart   return setArgResAttrDict</*isArg=*/true>(
16734a35a8bSMartin Erhart       op, op.getNumArguments(), index,
16834a35a8bSMartin Erhart       attributes ? attributes : DictionaryAttr::get(op->getContext()));
16934a35a8bSMartin Erhart }
17034a35a8bSMartin Erhart 
17134a35a8bSMartin Erhart void function_interface_impl::setResultAttrs(
17234a35a8bSMartin Erhart     FunctionOpInterface op, unsigned index,
17334a35a8bSMartin Erhart     ArrayRef<NamedAttribute> attributes) {
17434a35a8bSMartin Erhart   assert(index < op.getNumResults() && "invalid result number");
17534a35a8bSMartin Erhart   return setArgResAttrDict</*isArg=*/false>(
17634a35a8bSMartin Erhart       op, op.getNumResults(), index,
17734a35a8bSMartin Erhart       DictionaryAttr::get(op->getContext(), attributes));
17834a35a8bSMartin Erhart }
17934a35a8bSMartin Erhart 
18034a35a8bSMartin Erhart void function_interface_impl::setResultAttrs(FunctionOpInterface op,
18134a35a8bSMartin Erhart                                              unsigned index,
18234a35a8bSMartin Erhart                                              DictionaryAttr attributes) {
18334a35a8bSMartin Erhart   assert(index < op.getNumResults() && "invalid result number");
18434a35a8bSMartin Erhart   return setArgResAttrDict</*isArg=*/false>(
18534a35a8bSMartin Erhart       op, op.getNumResults(), index,
18634a35a8bSMartin Erhart       attributes ? attributes : DictionaryAttr::get(op->getContext()));
18734a35a8bSMartin Erhart }
18834a35a8bSMartin Erhart 
18934a35a8bSMartin Erhart void function_interface_impl::insertFunctionArguments(
19034a35a8bSMartin Erhart     FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
19134a35a8bSMartin Erhart     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
19234a35a8bSMartin Erhart     unsigned originalNumArgs, Type newType) {
19334a35a8bSMartin Erhart   assert(argIndices.size() == argTypes.size());
19434a35a8bSMartin Erhart   assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
19534a35a8bSMartin Erhart   assert(argIndices.size() == argLocs.size());
19634a35a8bSMartin Erhart   if (argIndices.empty())
19734a35a8bSMartin Erhart     return;
19834a35a8bSMartin Erhart 
19934a35a8bSMartin Erhart   // There are 3 things that need to be updated:
20034a35a8bSMartin Erhart   // - Function type.
20134a35a8bSMartin Erhart   // - Arg attrs.
20234a35a8bSMartin Erhart   // - Block arguments of entry block.
20334a35a8bSMartin Erhart   Block &entry = op->getRegion(0).front();
20434a35a8bSMartin Erhart 
20534a35a8bSMartin Erhart   // Update the argument attributes of the function.
20634a35a8bSMartin Erhart   ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
20734a35a8bSMartin Erhart   if (oldArgAttrs || !argAttrs.empty()) {
20834a35a8bSMartin Erhart     SmallVector<DictionaryAttr, 4> newArgAttrs;
20934a35a8bSMartin Erhart     newArgAttrs.reserve(originalNumArgs + argIndices.size());
21034a35a8bSMartin Erhart     unsigned oldIdx = 0;
21134a35a8bSMartin Erhart     auto migrate = [&](unsigned untilIdx) {
21234a35a8bSMartin Erhart       if (!oldArgAttrs) {
21334a35a8bSMartin Erhart         newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
21434a35a8bSMartin Erhart       } else {
21534a35a8bSMartin Erhart         auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
21634a35a8bSMartin Erhart         newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
21734a35a8bSMartin Erhart                            oldArgAttrRange.begin() + untilIdx);
21834a35a8bSMartin Erhart       }
21934a35a8bSMartin Erhart       oldIdx = untilIdx;
22034a35a8bSMartin Erhart     };
22134a35a8bSMartin Erhart     for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
22234a35a8bSMartin Erhart       migrate(argIndices[i]);
22334a35a8bSMartin Erhart       newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
22434a35a8bSMartin Erhart     }
22534a35a8bSMartin Erhart     migrate(originalNumArgs);
22634a35a8bSMartin Erhart     setAllArgAttrDicts(op, newArgAttrs);
22734a35a8bSMartin Erhart   }
22834a35a8bSMartin Erhart 
22934a35a8bSMartin Erhart   // Update the function type and any entry block arguments.
23034a35a8bSMartin Erhart   op.setFunctionTypeAttr(TypeAttr::get(newType));
23134a35a8bSMartin Erhart   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
23234a35a8bSMartin Erhart     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
23334a35a8bSMartin Erhart }
23434a35a8bSMartin Erhart 
23534a35a8bSMartin Erhart void function_interface_impl::insertFunctionResults(
23634a35a8bSMartin Erhart     FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
23734a35a8bSMartin Erhart     TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
23834a35a8bSMartin Erhart     unsigned originalNumResults, Type newType) {
23934a35a8bSMartin Erhart   assert(resultIndices.size() == resultTypes.size());
24034a35a8bSMartin Erhart   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
24134a35a8bSMartin Erhart   if (resultIndices.empty())
24234a35a8bSMartin Erhart     return;
24334a35a8bSMartin Erhart 
24434a35a8bSMartin Erhart   // There are 2 things that need to be updated:
24534a35a8bSMartin Erhart   // - Function type.
24634a35a8bSMartin Erhart   // - Result attrs.
24734a35a8bSMartin Erhart 
24834a35a8bSMartin Erhart   // Update the result attributes of the function.
24934a35a8bSMartin Erhart   ArrayAttr oldResultAttrs = op.getResAttrsAttr();
25034a35a8bSMartin Erhart   if (oldResultAttrs || !resultAttrs.empty()) {
25134a35a8bSMartin Erhart     SmallVector<DictionaryAttr, 4> newResultAttrs;
25234a35a8bSMartin Erhart     newResultAttrs.reserve(originalNumResults + resultIndices.size());
25334a35a8bSMartin Erhart     unsigned oldIdx = 0;
25434a35a8bSMartin Erhart     auto migrate = [&](unsigned untilIdx) {
25534a35a8bSMartin Erhart       if (!oldResultAttrs) {
25634a35a8bSMartin Erhart         newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
25734a35a8bSMartin Erhart       } else {
25834a35a8bSMartin Erhart         auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
25934a35a8bSMartin Erhart         newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
26034a35a8bSMartin Erhart                               oldResultAttrsRange.begin() + untilIdx);
26134a35a8bSMartin Erhart       }
26234a35a8bSMartin Erhart       oldIdx = untilIdx;
26334a35a8bSMartin Erhart     };
26434a35a8bSMartin Erhart     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
26534a35a8bSMartin Erhart       migrate(resultIndices[i]);
26634a35a8bSMartin Erhart       newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
26734a35a8bSMartin Erhart                                                    : resultAttrs[i]);
26834a35a8bSMartin Erhart     }
26934a35a8bSMartin Erhart     migrate(originalNumResults);
27034a35a8bSMartin Erhart     setAllResultAttrDicts(op, newResultAttrs);
27134a35a8bSMartin Erhart   }
27234a35a8bSMartin Erhart 
27334a35a8bSMartin Erhart   // Update the function type.
27434a35a8bSMartin Erhart   op.setFunctionTypeAttr(TypeAttr::get(newType));
27534a35a8bSMartin Erhart }
27634a35a8bSMartin Erhart 
27734a35a8bSMartin Erhart void function_interface_impl::eraseFunctionArguments(
27834a35a8bSMartin Erhart     FunctionOpInterface op, const BitVector &argIndices, Type newType) {
27934a35a8bSMartin Erhart   // There are 3 things that need to be updated:
28034a35a8bSMartin Erhart   // - Function type.
28134a35a8bSMartin Erhart   // - Arg attrs.
28234a35a8bSMartin Erhart   // - Block arguments of entry block.
28334a35a8bSMartin Erhart   Block &entry = op->getRegion(0).front();
28434a35a8bSMartin Erhart 
28534a35a8bSMartin Erhart   // Update the argument attributes of the function.
28634a35a8bSMartin Erhart   if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
28734a35a8bSMartin Erhart     SmallVector<DictionaryAttr, 4> newArgAttrs;
28834a35a8bSMartin Erhart     newArgAttrs.reserve(argAttrs.size());
28934a35a8bSMartin Erhart     for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
29034a35a8bSMartin Erhart       if (!argIndices[i])
29134a35a8bSMartin Erhart         newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i]));
29234a35a8bSMartin Erhart     setAllArgAttrDicts(op, newArgAttrs);
29334a35a8bSMartin Erhart   }
29434a35a8bSMartin Erhart 
29534a35a8bSMartin Erhart   // Update the function type and any entry block arguments.
29634a35a8bSMartin Erhart   op.setFunctionTypeAttr(TypeAttr::get(newType));
29734a35a8bSMartin Erhart   entry.eraseArguments(argIndices);
29834a35a8bSMartin Erhart }
29934a35a8bSMartin Erhart 
30034a35a8bSMartin Erhart void function_interface_impl::eraseFunctionResults(
30134a35a8bSMartin Erhart     FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
30234a35a8bSMartin Erhart   // There are 2 things that need to be updated:
30334a35a8bSMartin Erhart   // - Function type.
30434a35a8bSMartin Erhart   // - Result attrs.
30534a35a8bSMartin Erhart 
30634a35a8bSMartin Erhart   // Update the result attributes of the function.
30734a35a8bSMartin Erhart   if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
30834a35a8bSMartin Erhart     SmallVector<DictionaryAttr, 4> newResultAttrs;
30934a35a8bSMartin Erhart     newResultAttrs.reserve(resAttrs.size());
31034a35a8bSMartin Erhart     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
31134a35a8bSMartin Erhart       if (!resultIndices[i])
31234a35a8bSMartin Erhart         newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i]));
31334a35a8bSMartin Erhart     setAllResultAttrDicts(op, newResultAttrs);
31434a35a8bSMartin Erhart   }
31534a35a8bSMartin Erhart 
31634a35a8bSMartin Erhart   // Update the function type.
31734a35a8bSMartin Erhart   op.setFunctionTypeAttr(TypeAttr::get(newType));
31834a35a8bSMartin Erhart }
31934a35a8bSMartin Erhart 
32034a35a8bSMartin Erhart //===----------------------------------------------------------------------===//
32134a35a8bSMartin Erhart // Function type signature.
32234a35a8bSMartin Erhart //===----------------------------------------------------------------------===//
32334a35a8bSMartin Erhart 
32434a35a8bSMartin Erhart void function_interface_impl::setFunctionType(FunctionOpInterface op,
32534a35a8bSMartin Erhart                                               Type newType) {
32634a35a8bSMartin Erhart   unsigned oldNumArgs = op.getNumArguments();
32734a35a8bSMartin Erhart   unsigned oldNumResults = op.getNumResults();
32834a35a8bSMartin Erhart   op.setFunctionTypeAttr(TypeAttr::get(newType));
32934a35a8bSMartin Erhart   unsigned newNumArgs = op.getNumArguments();
33034a35a8bSMartin Erhart   unsigned newNumResults = op.getNumResults();
33134a35a8bSMartin Erhart 
33234a35a8bSMartin Erhart   // Functor used to update the argument and result attributes of the function.
33334a35a8bSMartin Erhart   auto emptyDict = DictionaryAttr::get(op.getContext());
33434a35a8bSMartin Erhart   auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
33534a35a8bSMartin Erhart     constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
33634a35a8bSMartin Erhart 
33734a35a8bSMartin Erhart     if (oldCount == newCount)
33834a35a8bSMartin Erhart       return;
33934a35a8bSMartin Erhart     // The new type has no arguments/results, just drop the attribute.
34034a35a8bSMartin Erhart     if (newCount == 0)
34134a35a8bSMartin Erhart       return removeArgResAttrs<isArgVal>(op);
34234a35a8bSMartin Erhart     ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
34334a35a8bSMartin Erhart     if (!attrs)
34434a35a8bSMartin Erhart       return;
34534a35a8bSMartin Erhart 
34634a35a8bSMartin Erhart     // The new type has less arguments/results, take the first N attributes.
34734a35a8bSMartin Erhart     if (newCount < oldCount)
34834a35a8bSMartin Erhart       return setAllArgResAttrDicts<isArgVal>(
34934a35a8bSMartin Erhart           op, attrs.getValue().take_front(newCount));
35034a35a8bSMartin Erhart 
35134a35a8bSMartin Erhart     // Otherwise, the new type has more arguments/results. Initialize the new
35234a35a8bSMartin Erhart     // arguments/results with empty dictionary attributes.
35334a35a8bSMartin Erhart     SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
35434a35a8bSMartin Erhart     newAttrs.resize(newCount, emptyDict);
35534a35a8bSMartin Erhart     setAllArgResAttrDicts<isArgVal>(op, newAttrs);
35634a35a8bSMartin Erhart   };
35734a35a8bSMartin Erhart 
35834a35a8bSMartin Erhart   // Update the argument and result attributes.
35934a35a8bSMartin Erhart   updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
36034a35a8bSMartin Erhart   updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
36134a35a8bSMartin Erhart }
362