1 //===- AttrTypeSubElements.h - Attr and Type SubElements -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains utilities for querying the sub elements of an attribute or 10 // type. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_IR_ATTRTYPESUBELEMENTS_H 15 #define MLIR_IR_ATTRTYPESUBELEMENTS_H 16 17 #include "mlir/IR/MLIRContext.h" 18 #include "mlir/IR/Visitors.h" 19 #include "mlir/Support/CyclicReplacerCache.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/DenseMap.h" 22 #include <optional> 23 24 namespace mlir { 25 class Attribute; 26 class Type; 27 28 //===----------------------------------------------------------------------===// 29 /// AttrTypeWalker 30 //===----------------------------------------------------------------------===// 31 32 /// This class provides a utility for walking attributes/types, and their sub 33 /// elements. Multiple walk functions may be registered. 34 class AttrTypeWalker { 35 public: 36 //===--------------------------------------------------------------------===// 37 // Application 38 //===--------------------------------------------------------------------===// 39 40 /// Walk the given attribute/type, and recursively walk any sub elements. 41 template <WalkOrder Order, typename T> 42 WalkResult walk(T element) { 43 return walkImpl(element, Order); 44 } 45 template <typename T> 46 WalkResult walk(T element) { 47 return walk<WalkOrder::PostOrder, T>(element); 48 } 49 50 //===--------------------------------------------------------------------===// 51 // Registration 52 //===--------------------------------------------------------------------===// 53 54 template <typename T> 55 using WalkFn = std::function<WalkResult(T)>; 56 57 /// Register a walk function for a given attribute or type. A walk function 58 /// must be convertible to any of the following forms(where `T` is a class 59 /// derived from `Type` or `Attribute`: 60 /// 61 /// * WalkResult(T) 62 /// - Returns a walk result, which can be used to control the walk 63 /// 64 /// * void(T) 65 /// - Returns void, i.e. the walk always continues. 66 /// 67 /// Note: When walking, the mostly recently added walk functions will be 68 /// invoked first. 69 void addWalk(WalkFn<Attribute> &&fn) { 70 attrWalkFns.emplace_back(std::move(fn)); 71 } 72 void addWalk(WalkFn<Type> &&fn) { typeWalkFns.push_back(std::move(fn)); } 73 74 /// Register a replacement function that doesn't match the default signature, 75 /// either because it uses a derived parameter type, or it uses a simplified 76 /// result type. 77 template <typename FnT, 78 typename T = typename llvm::function_traits< 79 std::decay_t<FnT>>::template arg_t<0>, 80 typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>, 81 Attribute, Type>, 82 typename ResultT = std::invoke_result_t<FnT, T>> 83 std::enable_if_t<!std::is_same_v<T, BaseT> || std::is_same_v<ResultT, void>> 84 addWalk(FnT &&callback) { 85 addWalk([callback = std::forward<FnT>(callback)](BaseT base) -> WalkResult { 86 if (auto derived = dyn_cast<T>(base)) { 87 if constexpr (std::is_convertible_v<ResultT, WalkResult>) 88 return callback(derived); 89 else 90 callback(derived); 91 } 92 return WalkResult::advance(); 93 }); 94 } 95 96 private: 97 WalkResult walkImpl(Attribute attr, WalkOrder order); 98 WalkResult walkImpl(Type type, WalkOrder order); 99 100 /// Internal implementation of the `walk` methods above. 101 template <typename T, typename WalkFns> 102 WalkResult walkImpl(T element, WalkFns &walkFns, WalkOrder order); 103 104 /// Walk the sub elements of the given interface. 105 template <typename T> 106 WalkResult walkSubElements(T interface, WalkOrder order); 107 108 /// The set of walk functions that map sub elements. 109 std::vector<WalkFn<Attribute>> attrWalkFns; 110 std::vector<WalkFn<Type>> typeWalkFns; 111 112 /// The set of visited attributes/types. 113 DenseMap<std::pair<const void *, int>, WalkResult> visitedAttrTypes; 114 }; 115 116 //===----------------------------------------------------------------------===// 117 /// AttrTypeReplacer 118 //===----------------------------------------------------------------------===// 119 120 namespace detail { 121 122 /// This class provides a base utility for replacing attributes/types, and their 123 /// sub elements. Multiple replacement functions may be registered. 124 /// 125 /// This base utility is uncached. Users can choose between two cached versions 126 /// of this replacer: 127 /// * For non-cyclic replacer logic, use `AttrTypeReplacer`. 128 /// * For cyclic replacer logic, use `CyclicAttrTypeReplacer`. 129 /// 130 /// Concrete implementations implement the following `replace` entry functions: 131 /// * Attribute replace(Attribute attr); 132 /// * Type replace(Type type); 133 template <typename Concrete> 134 class AttrTypeReplacerBase { 135 public: 136 //===--------------------------------------------------------------------===// 137 // Application 138 //===--------------------------------------------------------------------===// 139 140 /// Replace the elements within the given operation. If `replaceAttrs` is 141 /// true, this updates the attribute dictionary of the operation. If 142 /// `replaceLocs` is true, this also updates its location, and the locations 143 /// of any nested block arguments. If `replaceTypes` is true, this also 144 /// updates the result types of the operation, and the types of any nested 145 /// block arguments. 146 void replaceElementsIn(Operation *op, bool replaceAttrs = true, 147 bool replaceLocs = false, bool replaceTypes = false); 148 149 /// Replace the elements within the given operation, and all nested 150 /// operations. 151 void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs = true, 152 bool replaceLocs = false, 153 bool replaceTypes = false); 154 155 //===--------------------------------------------------------------------===// 156 // Registration 157 //===--------------------------------------------------------------------===// 158 159 /// A replacement mapping function, which returns either std::nullopt (to 160 /// signal the element wasn't handled), or a pair of the replacement element 161 /// and a WalkResult. 162 template <typename T> 163 using ReplaceFnResult = std::optional<std::pair<T, WalkResult>>; 164 template <typename T> 165 using ReplaceFn = std::function<ReplaceFnResult<T>(T)>; 166 167 /// Register a replacement function for mapping a given attribute or type. A 168 /// replacement function must be convertible to any of the following 169 /// forms(where `T` is a class derived from `Type` or `Attribute`, and `BaseT` 170 /// is either `Type` or `Attribute` respectively): 171 /// 172 /// * std::optional<BaseT>(T) 173 /// - This either returns a valid Attribute/Type in the case of success, 174 /// nullptr in the case of failure, or `std::nullopt` to signify that 175 /// additional replacement functions may be applied (i.e. this function 176 /// doesn't handle that instance). 177 /// 178 /// * std::optional<std::pair<BaseT, WalkResult>>(T) 179 /// - Similar to the above, but also allows specifying a WalkResult to 180 /// control the replacement of sub elements of a given attribute or 181 /// type. Returning a `skip` result, for example, will not recursively 182 /// process the resultant attribute or type value. 183 /// 184 /// Note: When replacing, the mostly recently added replacement functions will 185 /// be invoked first. 186 void addReplacement(ReplaceFn<Attribute> fn); 187 void addReplacement(ReplaceFn<Type> fn); 188 189 /// Register a replacement function that doesn't match the default signature, 190 /// either because it uses a derived parameter type, or it uses a simplified 191 /// result type. 192 template <typename FnT, 193 typename T = typename llvm::function_traits< 194 std::decay_t<FnT>>::template arg_t<0>, 195 typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>, 196 Attribute, Type>, 197 typename ResultT = std::invoke_result_t<FnT, T>> 198 std::enable_if_t<!std::is_same_v<T, BaseT> || 199 !std::is_convertible_v<ResultT, ReplaceFnResult<BaseT>>> 200 addReplacement(FnT &&callback) { 201 addReplacement([callback = std::forward<FnT>(callback)]( 202 BaseT base) -> ReplaceFnResult<BaseT> { 203 if (auto derived = dyn_cast<T>(base)) { 204 if constexpr (std::is_convertible_v<ResultT, std::optional<BaseT>>) { 205 std::optional<BaseT> result = callback(derived); 206 return result ? std::make_pair(*result, WalkResult::advance()) 207 : ReplaceFnResult<BaseT>(); 208 } else { 209 return callback(derived); 210 } 211 } 212 return ReplaceFnResult<BaseT>(); 213 }); 214 } 215 216 protected: 217 /// Invokes the registered replacement functions from most recently registered 218 /// to least recently registered until a successful replacement is returned. 219 /// Unless skipping is requested, invokes `replace` on sub-elements of the 220 /// current attr/type. 221 Attribute replaceBase(Attribute attr); 222 Type replaceBase(Type type); 223 224 private: 225 /// The set of replacement functions that map sub elements. 226 std::vector<ReplaceFn<Attribute>> attrReplacementFns; 227 std::vector<ReplaceFn<Type>> typeReplacementFns; 228 }; 229 230 } // namespace detail 231 232 /// This is an attribute/type replacer that is naively cached. It is best used 233 /// when the replacer logic is guaranteed to not contain cycles. Otherwise, any 234 /// re-occurrence of an in-progress element will be skipped. 235 class AttrTypeReplacer : public detail::AttrTypeReplacerBase<AttrTypeReplacer> { 236 public: 237 Attribute replace(Attribute attr); 238 Type replace(Type type); 239 240 private: 241 /// Shared concrete implementation of the public `replace` functions. Invokes 242 /// `replaceBase` with caching. 243 template <typename T> 244 T cachedReplaceImpl(T element); 245 246 // Stores the opaque pointer of an attribute or type. 247 DenseMap<const void *, const void *> cache; 248 }; 249 250 /// This is an attribute/type replacer that supports custom handling of cycles 251 /// in the replacer logic. In addition to registering replacer functions, it 252 /// allows registering cycle-breaking functions in the same style. 253 class CyclicAttrTypeReplacer 254 : public detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer> { 255 public: 256 CyclicAttrTypeReplacer(); 257 258 //===--------------------------------------------------------------------===// 259 // Application 260 //===--------------------------------------------------------------------===// 261 262 Attribute replace(Attribute attr); 263 Type replace(Type type); 264 265 //===--------------------------------------------------------------------===// 266 // Registration 267 //===--------------------------------------------------------------------===// 268 269 /// A cycle-breaking function. This is invoked if the same element is asked to 270 /// be replaced again when the first instance of it is still being replaced. 271 /// This function must not perform any more recursive `replace` calls. 272 /// If it is able to break the cycle, it should return a replacement result. 273 /// Otherwise, it can return std::nullopt to defer cycle breaking to the next 274 /// repeated element. However, the user must guarantee that, in any possible 275 /// cycle, there always exists at least one element that can break the cycle. 276 template <typename T> 277 using CycleBreakerFn = std::function<std::optional<T>(T)>; 278 279 /// Register a cycle-breaking function. 280 /// When breaking cycles, the mostly recently added cycle-breaking functions 281 /// will be invoked first. 282 void addCycleBreaker(CycleBreakerFn<Attribute> fn); 283 void addCycleBreaker(CycleBreakerFn<Type> fn); 284 285 /// Register a cycle-breaking function that doesn't match the default 286 /// signature. 287 template <typename FnT, 288 typename T = typename llvm::function_traits< 289 std::decay_t<FnT>>::template arg_t<0>, 290 typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>, 291 Attribute, Type>> 292 std::enable_if_t<!std::is_same_v<T, BaseT>> addCycleBreaker(FnT &&callback) { 293 addCycleBreaker([callback = std::forward<FnT>(callback)]( 294 BaseT base) -> std::optional<BaseT> { 295 if (auto derived = dyn_cast<T>(base)) 296 return callback(derived); 297 return std::nullopt; 298 }); 299 } 300 301 private: 302 /// Invokes the registered cycle-breaker functions from most recently 303 /// registered to least recently registered until a successful result is 304 /// returned. 305 std::optional<const void *> breakCycleImpl(void *element); 306 307 /// Shared concrete implementation of the public `replace` functions. 308 template <typename T> 309 T cachedReplaceImpl(T element); 310 311 /// The set of registered cycle-breaker functions. 312 std::vector<CycleBreakerFn<Attribute>> attrCycleBreakerFns; 313 std::vector<CycleBreakerFn<Type>> typeCycleBreakerFns; 314 315 /// A cache of previously-replaced attr/types. 316 /// The key of the cache is the opaque value of an AttrOrType. Using 317 /// AttrOrType allows distinguishing between the two types when invoking 318 /// cycle-breakers. Using its opaque value avoids the cyclic dependency issue 319 /// of directly using `AttrOrType` to instantiate the cache. 320 /// The value of the cache is just the opaque value of the attr/type itself 321 /// (not the PointerUnion). 322 using AttrOrType = PointerUnion<Attribute, Type>; 323 CyclicReplacerCache<void *, const void *> cache; 324 }; 325 326 //===----------------------------------------------------------------------===// 327 /// AttrTypeSubElementHandler 328 //===----------------------------------------------------------------------===// 329 330 /// This class is used by AttrTypeSubElementHandler instances to walking sub 331 /// attributes and types. 332 class AttrTypeImmediateSubElementWalker { 333 public: 334 AttrTypeImmediateSubElementWalker(function_ref<void(Attribute)> walkAttrsFn, 335 function_ref<void(Type)> walkTypesFn) 336 : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {} 337 338 /// Walk an attribute. 339 void walk(Attribute element); 340 /// Walk a type. 341 void walk(Type element); 342 /// Walk a range of attributes or types. 343 template <typename RangeT> 344 void walkRange(RangeT &&elements) { 345 for (auto element : elements) 346 walk(element); 347 } 348 349 private: 350 function_ref<void(Attribute)> walkAttrsFn; 351 function_ref<void(Type)> walkTypesFn; 352 }; 353 354 /// This class is used by AttrTypeSubElementHandler instances to process sub 355 /// element replacements. 356 template <typename T> 357 class AttrTypeSubElementReplacements { 358 public: 359 AttrTypeSubElementReplacements(ArrayRef<T> repls) : repls(repls) {} 360 361 /// Take the first N replacements as an ArrayRef, dropping them from 362 /// this replacement list. 363 ArrayRef<T> take_front(unsigned n) { 364 ArrayRef<T> elements = repls.take_front(n); 365 repls = repls.drop_front(n); 366 return elements; 367 } 368 369 private: 370 /// The current set of replacements. 371 ArrayRef<T> repls; 372 }; 373 using AttrSubElementReplacements = AttrTypeSubElementReplacements<Attribute>; 374 using TypeSubElementReplacements = AttrTypeSubElementReplacements<Type>; 375 376 /// This class provides support for interacting with the 377 /// SubElementInterfaces for different types of parameters. An 378 /// implementation of this class should be provided for any parameter class 379 /// that may contain an attribute or type. There are two main methods of 380 /// this class that need to be implemented: 381 /// 382 /// - walk 383 /// 384 /// This method should traverse into any sub elements of the parameter 385 /// using the provided walker, or by invoking handlers for sub-types. 386 /// 387 /// - replace 388 /// 389 /// This method should extract any necessary sub elements using the 390 /// provided replacer, or by invoking handlers for sub-types. The new 391 /// post-replacement parameter value should be returned. 392 /// 393 template <typename T, typename Enable = void> 394 struct AttrTypeSubElementHandler { 395 /// Default walk implementation that does nothing. 396 static inline void walk(const T ¶m, 397 AttrTypeImmediateSubElementWalker &walker) {} 398 399 /// Default replace implementation just forwards the parameter. 400 template <typename ParamT> 401 static inline decltype(auto) replace(ParamT &¶m, 402 AttrSubElementReplacements &attrRepls, 403 TypeSubElementReplacements &typeRepls) { 404 return std::forward<ParamT>(param); 405 } 406 407 /// Tag indicating that this handler does not support sub-elements. 408 using DefaultHandlerTag = void; 409 }; 410 411 /// Detect if any of the given parameter types has a sub-element handler. 412 namespace detail { 413 template <typename T> 414 using has_default_sub_element_handler_t = decltype(T::DefaultHandlerTag); 415 } // namespace detail 416 template <typename... Ts> 417 inline constexpr bool has_sub_attr_or_type_v = 418 (!llvm::is_detected<detail::has_default_sub_element_handler_t, Ts>::value || 419 ...); 420 421 /// Implementation for derived Attributes and Types. 422 template <typename T> 423 struct AttrTypeSubElementHandler< 424 T, std::enable_if_t<std::is_base_of_v<Attribute, T> || 425 std::is_base_of_v<Type, T>>> { 426 static void walk(T param, AttrTypeImmediateSubElementWalker &walker) { 427 walker.walk(param); 428 } 429 static T replace(T param, AttrSubElementReplacements &attrRepls, 430 TypeSubElementReplacements &typeRepls) { 431 if (!param) 432 return T(); 433 if constexpr (std::is_base_of_v<Attribute, T>) { 434 return cast<T>(attrRepls.take_front(1)[0]); 435 } else { 436 return cast<T>(typeRepls.take_front(1)[0]); 437 } 438 } 439 }; 440 /// Implementation for derived ArrayRef. 441 template <typename T> 442 struct AttrTypeSubElementHandler<ArrayRef<T>, 443 std::enable_if_t<has_sub_attr_or_type_v<T>>> { 444 using EltHandler = AttrTypeSubElementHandler<T>; 445 446 static void walk(ArrayRef<T> param, 447 AttrTypeImmediateSubElementWalker &walker) { 448 for (const T &subElement : param) 449 EltHandler::walk(subElement, walker); 450 } 451 static auto replace(ArrayRef<T> param, AttrSubElementReplacements &attrRepls, 452 TypeSubElementReplacements &typeRepls) { 453 // Normal attributes/types can extract using the replacer directly. 454 if constexpr (std::is_base_of_v<Attribute, T> && 455 sizeof(T) == sizeof(void *)) { 456 ArrayRef<Attribute> attrs = attrRepls.take_front(param.size()); 457 return ArrayRef<T>((const T *)attrs.data(), attrs.size()); 458 } else if constexpr (std::is_base_of_v<Type, T> && 459 sizeof(T) == sizeof(void *)) { 460 ArrayRef<Type> types = typeRepls.take_front(param.size()); 461 return ArrayRef<T>((const T *)types.data(), types.size()); 462 } else { 463 // Otherwise, we need to allocate storage for the new elements. 464 SmallVector<T> newElements; 465 for (const T &element : param) 466 newElements.emplace_back( 467 EltHandler::replace(element, attrRepls, typeRepls)); 468 return newElements; 469 } 470 } 471 }; 472 /// Implementation for Tuple. 473 template <typename... Ts> 474 struct AttrTypeSubElementHandler< 475 std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> { 476 static void walk(const std::tuple<Ts...> ¶m, 477 AttrTypeImmediateSubElementWalker &walker) { 478 std::apply( 479 [&](const Ts &...params) { 480 (AttrTypeSubElementHandler<Ts>::walk(params, walker), ...); 481 }, 482 param); 483 } 484 static auto replace(const std::tuple<Ts...> ¶m, 485 AttrSubElementReplacements &attrRepls, 486 TypeSubElementReplacements &typeRepls) { 487 return std::apply( 488 [&](const Ts &...params) 489 -> std::tuple<decltype(AttrTypeSubElementHandler<Ts>::replace( 490 params, attrRepls, typeRepls))...> { 491 return {AttrTypeSubElementHandler<Ts>::replace(params, attrRepls, 492 typeRepls)...}; 493 }, 494 param); 495 } 496 }; 497 498 namespace detail { 499 template <typename T> 500 struct is_tuple : public std::false_type {}; 501 template <typename... Ts> 502 struct is_tuple<std::tuple<Ts...>> : public std::true_type {}; 503 504 template <typename T> 505 struct is_pair : public std::false_type {}; 506 template <typename... Ts> 507 struct is_pair<std::pair<Ts...>> : public std::true_type {}; 508 509 template <typename T, typename... Ts> 510 using has_get_method = decltype(T::get(std::declval<Ts>()...)); 511 template <typename T, typename... Ts> 512 using has_get_as_key = decltype(std::declval<T>().getAsKey()); 513 514 /// This function provides the underlying implementation for the 515 /// SubElementInterface walk method, using the key type of the derived 516 /// attribute/type to interact with the individual parameters. 517 template <typename T> 518 void walkImmediateSubElementsImpl(T derived, 519 function_ref<void(Attribute)> walkAttrsFn, 520 function_ref<void(Type)> walkTypesFn) { 521 using ImplT = typename T::ImplType; 522 (void)derived; 523 (void)walkAttrsFn; 524 (void)walkTypesFn; 525 if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) { 526 auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey(); 527 528 // If we don't have any sub-elements, there is nothing to do. 529 if constexpr (!has_sub_attr_or_type_v<decltype(key)>) 530 return; 531 AttrTypeImmediateSubElementWalker walker(walkAttrsFn, walkTypesFn); 532 AttrTypeSubElementHandler<decltype(key)>::walk(key, walker); 533 } 534 } 535 536 /// This function invokes the proper `get` method for a type `T` with the given 537 /// values. 538 template <typename T, typename... Ts> 539 auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) { 540 // Prefer a direct `get` method if one exists. 541 if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) { 542 (void)ctx; 543 return T::get(std::forward<Ts>(params)...); 544 } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *, 545 Ts...>::value) { 546 return T::get(ctx, std::forward<Ts>(params)...); 547 } else { 548 // Otherwise, pass to the base get. 549 return T::Base::get(ctx, std::forward<Ts>(params)...); 550 } 551 } 552 553 /// This function provides the underlying implementation for the 554 /// SubElementInterface replace method, using the key type of the derived 555 /// attribute/type to interact with the individual parameters. 556 template <typename T> 557 auto replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs, 558 ArrayRef<Type> &replTypes) { 559 using ImplT = typename T::ImplType; 560 if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) { 561 auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey(); 562 563 // If we don't have any sub-elements, we can just return the original. 564 if constexpr (!has_sub_attr_or_type_v<decltype(key)>) { 565 return derived; 566 567 // Otherwise, we need to replace any necessary sub-elements. 568 } else { 569 // Functor used to build the replacement on success. 570 auto buildReplacement = [&](auto newKey, MLIRContext *ctx) { 571 if constexpr (is_tuple<decltype(key)>::value || 572 is_pair<decltype(key)>::value) { 573 return std::apply( 574 [&](auto &&...params) { 575 return constructSubElementReplacement<T>( 576 ctx, std::forward<decltype(params)>(params)...); 577 }, 578 newKey); 579 } else { 580 return constructSubElementReplacement<T>(ctx, newKey); 581 } 582 }; 583 584 AttrSubElementReplacements attrRepls(replAttrs); 585 TypeSubElementReplacements typeRepls(replTypes); 586 auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace( 587 key, attrRepls, typeRepls); 588 MLIRContext *ctx = derived.getContext(); 589 if constexpr (std::is_convertible_v<decltype(newKey), LogicalResult>) 590 return succeeded(newKey) ? buildReplacement(*newKey, ctx) : nullptr; 591 else 592 return buildReplacement(newKey, ctx); 593 } 594 } else { 595 return derived; 596 } 597 } 598 } // namespace detail 599 } // namespace mlir 600 601 #endif // MLIR_IR_ATTRTYPESUBELEMENTS_H 602