1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/FunctionInterfaces.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // Tablegen Interface Definitions 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Interfaces/FunctionInterfaces.cpp.inc" 18 19 //===----------------------------------------------------------------------===// 20 // Function Arguments and Results. 21 //===----------------------------------------------------------------------===// 22 23 static bool isEmptyAttrDict(Attribute attr) { 24 return llvm::cast<DictionaryAttr>(attr).empty(); 25 } 26 27 DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, 28 unsigned index) { 29 ArrayAttr attrs = op.getArgAttrsAttr(); 30 DictionaryAttr argAttrs = 31 attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr(); 32 return argAttrs; 33 } 34 35 DictionaryAttr 36 function_interface_impl::getResultAttrDict(FunctionOpInterface op, 37 unsigned index) { 38 ArrayAttr attrs = op.getResAttrsAttr(); 39 DictionaryAttr resAttrs = 40 attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr(); 41 return resAttrs; 42 } 43 44 ArrayRef<NamedAttribute> 45 function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { 46 auto argDict = getArgAttrDict(op, index); 47 return argDict ? argDict.getValue() : std::nullopt; 48 } 49 50 ArrayRef<NamedAttribute> 51 function_interface_impl::getResultAttrs(FunctionOpInterface op, 52 unsigned index) { 53 auto resultDict = getResultAttrDict(op, index); 54 return resultDict ? resultDict.getValue() : std::nullopt; 55 } 56 57 /// Get either the argument or result attributes array. 58 template <bool isArg> 59 static ArrayAttr getArgResAttrs(FunctionOpInterface op) { 60 if constexpr (isArg) 61 return op.getArgAttrsAttr(); 62 else 63 return op.getResAttrsAttr(); 64 } 65 66 /// Set either the argument or result attributes array. 67 template <bool isArg> 68 static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { 69 if constexpr (isArg) 70 op.setArgAttrsAttr(attrs); 71 else 72 op.setResAttrsAttr(attrs); 73 } 74 75 /// Erase either the argument or result attributes array. 76 template <bool isArg> 77 static void removeArgResAttrs(FunctionOpInterface op) { 78 if constexpr (isArg) 79 op.removeArgAttrsAttr(); 80 else 81 op.removeResAttrsAttr(); 82 } 83 84 /// Set all of the argument or result attribute dictionaries for a function. 85 template <bool isArg> 86 static void setAllArgResAttrDicts(FunctionOpInterface op, 87 ArrayRef<Attribute> attrs) { 88 if (llvm::all_of(attrs, isEmptyAttrDict)) 89 removeArgResAttrs<isArg>(op); 90 else 91 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs)); 92 } 93 94 void function_interface_impl::setAllArgAttrDicts( 95 FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) { 96 setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); 97 } 98 99 void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, 100 ArrayRef<Attribute> attrs) { 101 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { 102 return !attr ? DictionaryAttr::get(op->getContext()) : attr; 103 }); 104 setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs)); 105 } 106 107 void function_interface_impl::setAllResultAttrDicts( 108 FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) { 109 setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); 110 } 111 112 void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, 113 ArrayRef<Attribute> attrs) { 114 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { 115 return !attr ? DictionaryAttr::get(op->getContext()) : attr; 116 }); 117 setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs)); 118 } 119 120 /// Update the given index into an argument or result attribute dictionary. 121 template <bool isArg> 122 static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, 123 unsigned index, DictionaryAttr attrs) { 124 ArrayAttr allAttrs = getArgResAttrs<isArg>(op); 125 if (!allAttrs) { 126 if (attrs.empty()) 127 return; 128 129 // If this attribute is not empty, we need to create a new attribute array. 130 SmallVector<Attribute, 8> newAttrs(numTotalIndices, 131 DictionaryAttr::get(op->getContext())); 132 newAttrs[index] = attrs; 133 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs)); 134 return; 135 } 136 // Check to see if the attribute is different from what we already have. 137 if (allAttrs[index] == attrs) 138 return; 139 140 // If it is, check to see if the attribute array would now contain only empty 141 // dictionaries. 142 ArrayRef<Attribute> rawAttrArray = allAttrs.getValue(); 143 if (attrs.empty() && 144 llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && 145 llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) 146 return removeArgResAttrs<isArg>(op); 147 148 // Otherwise, create a new attribute array with the updated dictionary. 149 SmallVector<Attribute, 8> newAttrs(rawAttrArray); 150 newAttrs[index] = attrs; 151 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs)); 152 } 153 154 void function_interface_impl::setArgAttrs(FunctionOpInterface op, 155 unsigned index, 156 ArrayRef<NamedAttribute> attributes) { 157 assert(index < op.getNumArguments() && "invalid argument number"); 158 return setArgResAttrDict</*isArg=*/true>( 159 op, op.getNumArguments(), index, 160 DictionaryAttr::get(op->getContext(), attributes)); 161 } 162 163 void function_interface_impl::setArgAttrs(FunctionOpInterface op, 164 unsigned index, 165 DictionaryAttr attributes) { 166 return setArgResAttrDict</*isArg=*/true>( 167 op, op.getNumArguments(), index, 168 attributes ? attributes : DictionaryAttr::get(op->getContext())); 169 } 170 171 void function_interface_impl::setResultAttrs( 172 FunctionOpInterface op, unsigned index, 173 ArrayRef<NamedAttribute> attributes) { 174 assert(index < op.getNumResults() && "invalid result number"); 175 return setArgResAttrDict</*isArg=*/false>( 176 op, op.getNumResults(), index, 177 DictionaryAttr::get(op->getContext(), attributes)); 178 } 179 180 void function_interface_impl::setResultAttrs(FunctionOpInterface op, 181 unsigned index, 182 DictionaryAttr attributes) { 183 assert(index < op.getNumResults() && "invalid result number"); 184 return setArgResAttrDict</*isArg=*/false>( 185 op, op.getNumResults(), index, 186 attributes ? attributes : DictionaryAttr::get(op->getContext())); 187 } 188 189 void function_interface_impl::insertFunctionArguments( 190 FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes, 191 ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs, 192 unsigned originalNumArgs, Type newType) { 193 assert(argIndices.size() == argTypes.size()); 194 assert(argIndices.size() == argAttrs.size() || argAttrs.empty()); 195 assert(argIndices.size() == argLocs.size()); 196 if (argIndices.empty()) 197 return; 198 199 // There are 3 things that need to be updated: 200 // - Function type. 201 // - Arg attrs. 202 // - Block arguments of entry block. 203 Block &entry = op->getRegion(0).front(); 204 205 // Update the argument attributes of the function. 206 ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); 207 if (oldArgAttrs || !argAttrs.empty()) { 208 SmallVector<DictionaryAttr, 4> newArgAttrs; 209 newArgAttrs.reserve(originalNumArgs + argIndices.size()); 210 unsigned oldIdx = 0; 211 auto migrate = [&](unsigned untilIdx) { 212 if (!oldArgAttrs) { 213 newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx); 214 } else { 215 auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>(); 216 newArgAttrs.append(oldArgAttrRange.begin() + oldIdx, 217 oldArgAttrRange.begin() + untilIdx); 218 } 219 oldIdx = untilIdx; 220 }; 221 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) { 222 migrate(argIndices[i]); 223 newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]); 224 } 225 migrate(originalNumArgs); 226 setAllArgAttrDicts(op, newArgAttrs); 227 } 228 229 // Update the function type and any entry block arguments. 230 op.setFunctionTypeAttr(TypeAttr::get(newType)); 231 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) 232 entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); 233 } 234 235 void function_interface_impl::insertFunctionResults( 236 FunctionOpInterface op, ArrayRef<unsigned> resultIndices, 237 TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs, 238 unsigned originalNumResults, Type newType) { 239 assert(resultIndices.size() == resultTypes.size()); 240 assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); 241 if (resultIndices.empty()) 242 return; 243 244 // There are 2 things that need to be updated: 245 // - Function type. 246 // - Result attrs. 247 248 // Update the result attributes of the function. 249 ArrayAttr oldResultAttrs = op.getResAttrsAttr(); 250 if (oldResultAttrs || !resultAttrs.empty()) { 251 SmallVector<DictionaryAttr, 4> newResultAttrs; 252 newResultAttrs.reserve(originalNumResults + resultIndices.size()); 253 unsigned oldIdx = 0; 254 auto migrate = [&](unsigned untilIdx) { 255 if (!oldResultAttrs) { 256 newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx); 257 } else { 258 auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>(); 259 newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx, 260 oldResultAttrsRange.begin() + untilIdx); 261 } 262 oldIdx = untilIdx; 263 }; 264 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) { 265 migrate(resultIndices[i]); 266 newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{} 267 : resultAttrs[i]); 268 } 269 migrate(originalNumResults); 270 setAllResultAttrDicts(op, newResultAttrs); 271 } 272 273 // Update the function type. 274 op.setFunctionTypeAttr(TypeAttr::get(newType)); 275 } 276 277 void function_interface_impl::eraseFunctionArguments( 278 FunctionOpInterface op, const BitVector &argIndices, Type newType) { 279 // There are 3 things that need to be updated: 280 // - Function type. 281 // - Arg attrs. 282 // - Block arguments of entry block. 283 Block &entry = op->getRegion(0).front(); 284 285 // Update the argument attributes of the function. 286 if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { 287 SmallVector<DictionaryAttr, 4> newArgAttrs; 288 newArgAttrs.reserve(argAttrs.size()); 289 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) 290 if (!argIndices[i]) 291 newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i])); 292 setAllArgAttrDicts(op, newArgAttrs); 293 } 294 295 // Update the function type and any entry block arguments. 296 op.setFunctionTypeAttr(TypeAttr::get(newType)); 297 entry.eraseArguments(argIndices); 298 } 299 300 void function_interface_impl::eraseFunctionResults( 301 FunctionOpInterface op, const BitVector &resultIndices, Type newType) { 302 // There are 2 things that need to be updated: 303 // - Function type. 304 // - Result attrs. 305 306 // Update the result attributes of the function. 307 if (ArrayAttr resAttrs = op.getResAttrsAttr()) { 308 SmallVector<DictionaryAttr, 4> newResultAttrs; 309 newResultAttrs.reserve(resAttrs.size()); 310 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) 311 if (!resultIndices[i]) 312 newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i])); 313 setAllResultAttrDicts(op, newResultAttrs); 314 } 315 316 // Update the function type. 317 op.setFunctionTypeAttr(TypeAttr::get(newType)); 318 } 319 320 //===----------------------------------------------------------------------===// 321 // Function type signature. 322 //===----------------------------------------------------------------------===// 323 324 void function_interface_impl::setFunctionType(FunctionOpInterface op, 325 Type newType) { 326 unsigned oldNumArgs = op.getNumArguments(); 327 unsigned oldNumResults = op.getNumResults(); 328 op.setFunctionTypeAttr(TypeAttr::get(newType)); 329 unsigned newNumArgs = op.getNumArguments(); 330 unsigned newNumResults = op.getNumResults(); 331 332 // Functor used to update the argument and result attributes of the function. 333 auto emptyDict = DictionaryAttr::get(op.getContext()); 334 auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { 335 constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>; 336 337 if (oldCount == newCount) 338 return; 339 // The new type has no arguments/results, just drop the attribute. 340 if (newCount == 0) 341 return removeArgResAttrs<isArgVal>(op); 342 ArrayAttr attrs = getArgResAttrs<isArgVal>(op); 343 if (!attrs) 344 return; 345 346 // The new type has less arguments/results, take the first N attributes. 347 if (newCount < oldCount) 348 return setAllArgResAttrDicts<isArgVal>( 349 op, attrs.getValue().take_front(newCount)); 350 351 // Otherwise, the new type has more arguments/results. Initialize the new 352 // arguments/results with empty dictionary attributes. 353 SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end()); 354 newAttrs.resize(newCount, emptyDict); 355 setAllArgResAttrDicts<isArgVal>(op, newAttrs); 356 }; 357 358 // Update the argument and result attributes. 359 updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); 360 updateAttrFn(std::false_type{}, oldNumResults, newNumResults); 361 } 362