xref: /llvm-project/mlir/lib/IR/AttrTypeSubElements.cpp (revision 01eb071de014759101940096a31d65babc8af04e)
1 //===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Operation.h"
10 #include <optional>
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // AttrTypeWalker
16 //===----------------------------------------------------------------------===//
17 
18 WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) {
19   return walkImpl(attr, attrWalkFns, order);
20 }
21 WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) {
22   return walkImpl(type, typeWalkFns, order);
23 }
24 
25 template <typename T, typename WalkFns>
26 WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
27                                     WalkOrder order) {
28   // Check if we've already walk this element before.
29   auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);
30   auto [it, inserted] =
31       visitedAttrTypes.try_emplace(key, WalkResult::advance());
32   if (!inserted)
33     return it->second;
34 
35   // If we are walking in post order, walk the sub elements first.
36   if (order == WalkOrder::PostOrder) {
37     if (walkSubElements(element, order).wasInterrupted())
38       return visitedAttrTypes[key] = WalkResult::interrupt();
39   }
40 
41   // Walk this element, bailing if skipped or interrupted.
42   for (auto &walkFn : llvm::reverse(walkFns)) {
43     WalkResult walkResult = walkFn(element);
44     if (walkResult.wasInterrupted())
45       return visitedAttrTypes[key] = WalkResult::interrupt();
46     if (walkResult.wasSkipped())
47       return WalkResult::advance();
48   }
49 
50   // If we are walking in pre-order, walk the sub elements last.
51   if (order == WalkOrder::PreOrder) {
52     if (walkSubElements(element, order).wasInterrupted())
53       return WalkResult::interrupt();
54   }
55   return WalkResult::advance();
56 }
57 
58 template <typename T>
59 WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
60   WalkResult result = WalkResult::advance();
61   auto walkFn = [&](auto element) {
62     if (element && !result.wasInterrupted())
63       result = walkImpl(element, order);
64   };
65   interface.walkImmediateSubElements(walkFn, walkFn);
66   return result.wasInterrupted() ? result : WalkResult::advance();
67 }
68 
69 //===----------------------------------------------------------------------===//
70 /// AttrTypeReplacerBase
71 //===----------------------------------------------------------------------===//
72 
73 template <typename Concrete>
74 void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
75     ReplaceFn<Attribute> fn) {
76   attrReplacementFns.emplace_back(std::move(fn));
77 }
78 
79 template <typename Concrete>
80 void detail::AttrTypeReplacerBase<Concrete>::addReplacement(
81     ReplaceFn<Type> fn) {
82   typeReplacementFns.push_back(std::move(fn));
83 }
84 
85 template <typename Concrete>
86 void detail::AttrTypeReplacerBase<Concrete>::replaceElementsIn(
87     Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
88   // Functor that replaces the given element if the new value is different,
89   // otherwise returns nullptr.
90   auto replaceIfDifferent = [&](auto element) {
91     auto replacement = static_cast<Concrete *>(this)->replace(element);
92     return (replacement && replacement != element) ? replacement : nullptr;
93   };
94 
95   // Update the attribute dictionary.
96   if (replaceAttrs) {
97     if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary()))
98       op->setAttrs(cast<DictionaryAttr>(newAttrs));
99   }
100 
101   // If we aren't updating locations or types, we're done.
102   if (!replaceTypes && !replaceLocs)
103     return;
104 
105   // Update the location.
106   if (replaceLocs) {
107     if (Attribute newLoc = replaceIfDifferent(op->getLoc()))
108       op->setLoc(cast<LocationAttr>(newLoc));
109   }
110 
111   // Update the result types.
112   if (replaceTypes) {
113     for (OpResult result : op->getResults())
114       if (Type newType = replaceIfDifferent(result.getType()))
115         result.setType(newType);
116   }
117 
118   // Update any nested block arguments.
119   for (Region &region : op->getRegions()) {
120     for (Block &block : region) {
121       for (BlockArgument &arg : block.getArguments()) {
122         if (replaceLocs) {
123           if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))
124             arg.setLoc(cast<LocationAttr>(newLoc));
125         }
126 
127         if (replaceTypes) {
128           if (Type newType = replaceIfDifferent(arg.getType()))
129             arg.setType(newType);
130         }
131       }
132     }
133   }
134 }
135 
136 template <typename Concrete>
137 void detail::AttrTypeReplacerBase<Concrete>::recursivelyReplaceElementsIn(
138     Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
139   op->walk([&](Operation *nestedOp) {
140     replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
141   });
142 }
143 
144 template <typename T, typename Replacer>
145 static void updateSubElementImpl(T element, Replacer &replacer,
146                                  SmallVectorImpl<T> &newElements,
147                                  FailureOr<bool> &changed) {
148   // Bail early if we failed at any point.
149   if (failed(changed))
150     return;
151 
152   // Guard against potentially null inputs. We always map null to null.
153   if (!element) {
154     newElements.push_back(nullptr);
155     return;
156   }
157 
158   // Replace the element.
159   if (T result = replacer.replace(element)) {
160     newElements.push_back(result);
161     if (result != element)
162       changed = true;
163   } else {
164     changed = failure();
165   }
166 }
167 
168 template <typename T, typename Replacer>
169 static T replaceSubElements(T interface, Replacer &replacer) {
170   // Walk the current sub-elements, replacing them as necessary.
171   SmallVector<Attribute, 16> newAttrs;
172   SmallVector<Type, 16> newTypes;
173   FailureOr<bool> changed = false;
174   interface.walkImmediateSubElements(
175       [&](Attribute element) {
176         updateSubElementImpl(element, replacer, newAttrs, changed);
177       },
178       [&](Type element) {
179         updateSubElementImpl(element, replacer, newTypes, changed);
180       });
181   if (failed(changed))
182     return nullptr;
183 
184   // If any sub-elements changed, use the new elements during the replacement.
185   T result = interface;
186   if (*changed)
187     result = interface.replaceImmediateSubElements(newAttrs, newTypes);
188   return result;
189 }
190 
191 /// Shared implementation of replacing a given attribute or type element.
192 template <typename T, typename ReplaceFns, typename Replacer>
193 static T replaceElementImpl(T element, ReplaceFns &replaceFns,
194                             Replacer &replacer) {
195   T result = element;
196   WalkResult walkResult = WalkResult::advance();
197   for (auto &replaceFn : llvm::reverse(replaceFns)) {
198     if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
199       std::tie(result, walkResult) = *newRes;
200       break;
201     }
202   }
203 
204   // If an error occurred, return nullptr to indicate failure.
205   if (walkResult.wasInterrupted() || !result) {
206     return nullptr;
207   }
208 
209   // Handle replacing sub-elements if this element is also a container.
210   if (!walkResult.wasSkipped()) {
211     // Replace the sub elements of this element, bailing if we fail.
212     if (!(result = replaceSubElements(result, replacer))) {
213       return nullptr;
214     }
215   }
216 
217   return result;
218 }
219 
220 template <typename Concrete>
221 Attribute detail::AttrTypeReplacerBase<Concrete>::replaceBase(Attribute attr) {
222   return replaceElementImpl(attr, attrReplacementFns,
223                             *static_cast<Concrete *>(this));
224 }
225 
226 template <typename Concrete>
227 Type detail::AttrTypeReplacerBase<Concrete>::replaceBase(Type type) {
228   return replaceElementImpl(type, typeReplacementFns,
229                             *static_cast<Concrete *>(this));
230 }
231 
232 //===----------------------------------------------------------------------===//
233 /// AttrTypeReplacer
234 //===----------------------------------------------------------------------===//
235 
236 template class detail::AttrTypeReplacerBase<AttrTypeReplacer>;
237 
238 template <typename T>
239 T AttrTypeReplacer::cachedReplaceImpl(T element) {
240   const void *opaqueElement = element.getAsOpaquePointer();
241   auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
242   if (!inserted)
243     return T::getFromOpaquePointer(it->second);
244 
245   T result = replaceBase(element);
246 
247   cache[opaqueElement] = result.getAsOpaquePointer();
248   return result;
249 }
250 
251 Attribute AttrTypeReplacer::replace(Attribute attr) {
252   return cachedReplaceImpl(attr);
253 }
254 
255 Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }
256 
257 //===----------------------------------------------------------------------===//
258 /// CyclicAttrTypeReplacer
259 //===----------------------------------------------------------------------===//
260 
261 template class detail::AttrTypeReplacerBase<CyclicAttrTypeReplacer>;
262 
263 CyclicAttrTypeReplacer::CyclicAttrTypeReplacer()
264     : cache([&](void *attr) { return breakCycleImpl(attr); }) {}
265 
266 void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Attribute> fn) {
267   attrCycleBreakerFns.emplace_back(std::move(fn));
268 }
269 
270 void CyclicAttrTypeReplacer::addCycleBreaker(CycleBreakerFn<Type> fn) {
271   typeCycleBreakerFns.emplace_back(std::move(fn));
272 }
273 
274 template <typename T>
275 T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
276   void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
277   CyclicReplacerCache<void *, const void *>::CacheEntry cacheEntry =
278       cache.lookupOrInit(opaqueTaggedElement);
279   if (auto resultOpt = cacheEntry.get())
280     return T::getFromOpaquePointer(*resultOpt);
281 
282   T result = replaceBase(element);
283 
284   cacheEntry.resolve(result.getAsOpaquePointer());
285   return result;
286 }
287 
288 Attribute CyclicAttrTypeReplacer::replace(Attribute attr) {
289   return cachedReplaceImpl(attr);
290 }
291 
292 Type CyclicAttrTypeReplacer::replace(Type type) {
293   return cachedReplaceImpl(type);
294 }
295 
296 std::optional<const void *>
297 CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
298   AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
299   if (auto attr = dyn_cast<Attribute>(attrType)) {
300     for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
301       if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
302         return newRes->getAsOpaquePointer();
303       }
304     }
305   } else {
306     auto type = dyn_cast<Type>(attrType);
307     for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
308       if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
309         return newRes->getAsOpaquePointer();
310       }
311     }
312   }
313   return std::nullopt;
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // AttrTypeImmediateSubElementWalker
318 //===----------------------------------------------------------------------===//
319 
320 void AttrTypeImmediateSubElementWalker::walk(Attribute element) {
321   if (element)
322     walkAttrsFn(element);
323 }
324 
325 void AttrTypeImmediateSubElementWalker::walk(Type element) {
326   if (element)
327     walkTypesFn(element);
328 }
329