1 //===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Operation.h" 10 #include <optional> 11 12 using namespace mlir; 13 14 //===----------------------------------------------------------------------===// 15 // AttrTypeWalker 16 //===----------------------------------------------------------------------===// 17 18 WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) { 19 return walkImpl(attr, attrWalkFns, order); 20 } 21 WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) { 22 return walkImpl(type, typeWalkFns, order); 23 } 24 25 template <typename T, typename WalkFns> 26 WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns, 27 WalkOrder order) { 28 // Check if we've already walk this element before. 29 auto key = std::make_pair(element.getAsOpaquePointer(), (int)order); 30 auto [it, inserted] = 31 visitedAttrTypes.try_emplace(key, WalkResult::advance()); 32 if (!inserted) 33 return it->second; 34 35 // If we are walking in post order, walk the sub elements first. 36 if (order == WalkOrder::PostOrder) { 37 if (walkSubElements(element, order).wasInterrupted()) 38 return visitedAttrTypes[key] = WalkResult::interrupt(); 39 } 40 41 // Walk this element, bailing if skipped or interrupted. 42 for (auto &walkFn : llvm::reverse(walkFns)) { 43 WalkResult walkResult = walkFn(element); 44 if (walkResult.wasInterrupted()) 45 return visitedAttrTypes[key] = WalkResult::interrupt(); 46 if (walkResult.wasSkipped()) 47 return WalkResult::advance(); 48 } 49 50 // If we are walking in pre-order, walk the sub elements last. 51 if (order == WalkOrder::PreOrder) { 52 if (walkSubElements(element, order).wasInterrupted()) 53 return WalkResult::interrupt(); 54 } 55 return WalkResult::advance(); 56 } 57 58 template <typename T> 59 WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) { 60 WalkResult result = WalkResult::advance(); 61 auto walkFn = [&](auto element) { 62 if (element && !result.wasInterrupted()) 63 result = walkImpl(element, order); 64 }; 65 interface.walkImmediateSubElements(walkFn, walkFn); 66 return result.wasInterrupted() ? result : WalkResult::advance(); 67 } 68 69 //===----------------------------------------------------------------------===// 70 /// AttrTypeReplacerBase 71 //===----------------------------------------------------------------------===// 72 73 template <typename Concrete> 74 void detail::AttrTypeReplacerBase<Concrete>::addReplacement( 75 ReplaceFn<Attribute> fn) { 76 attrReplacementFns.emplace_back(std::move(fn)); 77 } 78 79 template <typename Concrete> 80 void detail::AttrTypeReplacerBase<Concrete>::addReplacement( 81 ReplaceFn<Type> fn) { 82 typeReplacementFns.push_back(std::move(fn)); 83 } 84 85 template <typename Concrete> 86 void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn( 87 Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { 88 // Functor that replaces the given element if the new value is different, 89 // otherwise returns nullptr. 90 auto replaceIfDifferent = [&](auto element) { 91 auto replacement = static_cast<Concrete *>(this)->replace(element); 92 return (replacement && replacement != element) ? replacement : nullptr; 93 }; 94 95 // Update the attribute dictionary. 96 if (replaceAttrs) { 97 if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary())) 98 op->setAttrs(cast<DictionaryAttr>(newAttrs)); 99 } 100 101 // If we aren't updating locations or types, we're done. 102 if (!replaceTypes && !replaceLocs) 103 return; 104 105 // Update the location. 106 if (replaceLocs) { 107 if (Attribute newLoc = replaceIfDifferent(op->getLoc())) 108 op->setLoc(cast<LocationAttr>(newLoc)); 109 } 110 111 // Update the result types. 112 if (replaceTypes) { 113 for (OpResult result : op->getResults()) 114 if (Type newType = replaceIfDifferent(result.getType())) 115 result.setType(newType); 116 } 117 118 // Update any nested block arguments. 119 for (Region ®ion : op->getRegions()) { 120 for (Block &block : region) { 121 for (BlockArgument &arg : block.getArguments()) { 122 if (replaceLocs) { 123 if (Attribute newLoc = replaceIfDifferent(arg.getLoc())) 124 arg.setLoc(cast<LocationAttr>(newLoc)); 125 } 126 127 if (replaceTypes) { 128 if (Type newType = replaceIfDifferent(arg.getType())) 129 arg.setType(newType); 130 } 131 } 132 } 133 } 134 } 135 136 template <typename Concrete> 137 void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn( 138 Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { 139 op->walk([&](Operation *nestedOp) { 140 replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes); 141 }); 142 } 143 144 template <typename T, typename Replacer> 145 static void updateSubElementImpl(T element, Replacer &replacer, 146 SmallVectorImpl<T> &newElements, 147 FailureOr<bool> &changed) { 148 // Bail early if we failed at any point. 149 if (failed(changed)) 150 return; 151 152 // Guard against potentially null inputs. We always map null to null. 153 if (!element) { 154 newElements.push_back(nullptr); 155 return; 156 } 157 158 // Replace the element. 159 if (T result = replacer.replace(element)) { 160 newElements.push_back(result); 161 if (result != element) 162 changed = true; 163 } else { 164 changed = failure(); 165 } 166 } 167 168 template <typename T, typename Replacer> 169 static T replaceSubElements(T interface, Replacer &replacer) { 170 // Walk the current sub-elements, replacing them as necessary. 171 SmallVector<Attribute, 16> newAttrs; 172 SmallVector<Type, 16> newTypes; 173 FailureOr<bool> changed = false; 174 interface.walkImmediateSubElements( 175 [&](Attribute element) { 176 updateSubElementImpl(element, replacer, newAttrs, changed); 177 }, 178 [&](Type element) { 179 updateSubElementImpl(element, replacer, newTypes, changed); 180 }); 181 if (failed(changed)) 182 return nullptr; 183 184 // If any sub-elements changed, use the new elements during the replacement. 185 T result = interface; 186 if (*changed) 187 result = interface.replaceImmediateSubElements(newAttrs, newTypes); 188 return result; 189 } 190 191 /// Shared implementation of replacing a given attribute or type element. 192 template <typename T, typename ReplaceFns, typename Replacer> 193 static T replaceElementImpl(T element, ReplaceFns &replaceFns, 194 Replacer &replacer) { 195 T result = element; 196 WalkResult walkResult = WalkResult::advance(); 197 for (auto &replaceFn : llvm::reverse(replaceFns)) { 198 if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) { 199 std::tie(result, walkResult) = *newRes; 200 break; 201 } 202 } 203 204 // If an error occurred, return nullptr to indicate failure. 205 if (walkResult.wasInterrupted() || !result) { 206 return nullptr; 207 } 208 209 // Handle replacing sub-elements if this element is also a container. 210 if (!walkResult.wasSkipped()) { 211 // Replace the sub elements of this element, bailing if we fail. 212 if (!(result = replaceSubElements(result, replacer))) { 213 return nullptr; 214 } 215 } 216 217 return result; 218 } 219 220 template <typename Concrete> 221 Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) { 222 return replaceElementImpl(attr, attrReplacementFns, 223 *static_cast<Concrete *>(this)); 224 } 225 226 template <typename Concrete> 227 Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) { 228 return replaceElementImpl(type, typeReplacementFns, 229 *static_cast<Concrete *>(this)); 230 } 231 232 //===----------------------------------------------------------------------===// 233 /// AttrTypeReplacer 234 //===----------------------------------------------------------------------===// 235 236 template class detail::AttrTypeReplacerBase<AttrTypeReplacer>; 237 238 template <typename T> 239 T AttrTypeReplacer::cachedReplaceImpl(T element) { 240 const void *opaqueElement = element.getAsOpaquePointer(); 241 auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement); 242 if (!inserted) 243 return T::getFromOpaquePointer(it->second); 244 245 T result = replaceBase(element); 246 247 cache[opaqueElement] = result.getAsOpaquePointer(); 248 return result; 249 } 250 251 Attribute AttrTypeReplacer::replace(Attribute attr) { 252 return cachedReplaceImpl(attr); 253 } 254 255 Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); } 256 257 //===----------------------------------------------------------------------===// 258 /// CyclicAttrTypeReplacer 259 //===----------------------------------------------------------------------===// 260 261 template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>; 262 263 CyclicAttrTypeReplacer::CyclicAttrTypeReplacer() 264 : cache([&](void *attr) { return breakCycleImpl(attr); }) {} 265 266 void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) { 267 attrCycleBreakerFns.emplace_back(std::move(fn)); 268 } 269 270 void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) { 271 typeCycleBreakerFns.emplace_back(std::move(fn)); 272 } 273 274 template <typename T> 275 T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) { 276 void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue(); 277 CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry = 278 cache.lookupOrInit(opaqueTaggedElement); 279 if (auto resultOpt = cacheEntry.get()) 280 return T::getFromOpaquePointer(*resultOpt); 281 282 T result = replaceBase(element); 283 284 cacheEntry.resolve(result.getAsOpaquePointer()); 285 return result; 286 } 287 288 Attribute CyclicAttrTypeReplacer::replace(Attribute attr) { 289 return cachedReplaceImpl(attr); 290 } 291 292 Type CyclicAttrTypeReplacer::replace(Type type) { 293 return cachedReplaceImpl(type); 294 } 295 296 std::optional<const void *> 297 CyclicAttrTypeReplacer::breakCycleImpl(void *element) { 298 AttrOrType attrType = AttrOrType::getFromOpaqueValue(element); 299 if (auto attr = dyn_cast<Attribute>(attrType)) { 300 for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) { 301 if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) { 302 return newRes->getAsOpaquePointer(); 303 } 304 } 305 } else { 306 auto type = dyn_cast<Type>(attrType); 307 for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) { 308 if (std::optional<Type> newRes = cyclicReplaceFn(type)) { 309 return newRes->getAsOpaquePointer(); 310 } 311 } 312 } 313 return std::nullopt; 314 } 315 316 //===----------------------------------------------------------------------===// 317 // AttrTypeImmediateSubElementWalker 318 //===----------------------------------------------------------------------===// 319 320 void AttrTypeImmediateSubElementWalker::walk(Attribute element) { 321 if (element) 322 walkAttrsFn(element); 323 } 324 325 void AttrTypeImmediateSubElementWalker::walk(Type element) { 326 if (element) 327 walkTypesFn(element); 328 } 329