xref: /llvm-project/mlir/include/mlir/Pass/AnalysisManager.h (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
1 //===- AnalysisManager.h - Analysis Management Infrastructure ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_PASS_ANALYSISMANAGER_H
10 #define MLIR_PASS_ANALYSISMANAGER_H
11 
12 #include "mlir/IR/Operation.h"
13 #include "mlir/Pass/PassInstrumentation.h"
14 #include "mlir/Support/LLVM.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/Support/TypeName.h"
19 #include <optional>
20 
21 namespace mlir {
22 class AnalysisManager;
23 
24 //===----------------------------------------------------------------------===//
25 // Analysis Preservation and Concept Modeling
26 //===----------------------------------------------------------------------===//
27 
28 namespace detail {
29 /// A utility class to represent the analyses that are known to be preserved.
30 class PreservedAnalyses {
31   /// A type used to represent all potential analyses.
32   struct AllAnalysesType {};
33 
34 public:
35   /// Mark all analyses as preserved.
36   void preserveAll() { preservedIDs.insert(TypeID::get<AllAnalysesType>()); }
37 
38   /// Returns true if all analyses were marked preserved.
39   bool isAll() const {
40     return preservedIDs.count(TypeID::get<AllAnalysesType>());
41   }
42 
43   /// Returns true if no analyses were marked preserved.
44   bool isNone() const { return preservedIDs.empty(); }
45 
46   /// Preserve the given analyses.
47   template <typename AnalysisT>
48   void preserve() {
49     preserve(TypeID::get<AnalysisT>());
50   }
51   template <typename AnalysisT, typename AnalysisT2, typename... OtherAnalysesT>
52   void preserve() {
53     preserve<AnalysisT>();
54     preserve<AnalysisT2, OtherAnalysesT...>();
55   }
56   void preserve(TypeID id) { preservedIDs.insert(id); }
57 
58   /// Returns true if the given analysis has been marked as preserved. Note that
59   /// this simply checks for the presence of a given analysis ID and should not
60   /// be used as a general preservation checker.
61   template <typename AnalysisT>
62   bool isPreserved() const {
63     return isPreserved(TypeID::get<AnalysisT>());
64   }
65   bool isPreserved(TypeID id) const { return preservedIDs.count(id); }
66 
67 private:
68   /// Remove the analysis from preserved set.
69   template <typename AnalysisT>
70   void unpreserve() {
71     preservedIDs.erase(TypeID::get<AnalysisT>());
72   }
73 
74   /// AnalysisModel need access to unpreserve().
75   template <typename>
76   friend struct AnalysisModel;
77 
78   /// The set of analyses that are known to be preserved.
79   SmallPtrSet<TypeID, 2> preservedIDs;
80 };
81 
82 namespace analysis_impl {
83 /// Trait to check if T provides a static 'isInvalidated' method.
84 template <typename T, typename... Args>
85 using has_is_invalidated = decltype(std::declval<T &>().isInvalidated(
86     std::declval<const PreservedAnalyses &>()));
87 
88 /// Implementation of 'isInvalidated' if the analysis provides a definition.
89 template <typename AnalysisT>
90 std::enable_if_t<llvm::is_detected<has_is_invalidated, AnalysisT>::value, bool>
91 isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
92   return analysis.isInvalidated(pa);
93 }
94 /// Default implementation of 'isInvalidated'.
95 template <typename AnalysisT>
96 std::enable_if_t<!llvm::is_detected<has_is_invalidated, AnalysisT>::value, bool>
97 isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
98   return !pa.isPreserved<AnalysisT>();
99 }
100 } // namespace analysis_impl
101 
102 /// The abstract polymorphic base class representing an analysis.
103 struct AnalysisConcept {
104   virtual ~AnalysisConcept() = default;
105 
106   /// A hook used to query analyses for invalidation. Given a preserved analysis
107   /// set, returns true if it should truly be invalidated. This allows for more
108   /// fine-tuned invalidation in cases where an analysis wasn't explicitly
109   /// marked preserved, but may be preserved(or invalidated) based upon other
110   /// properties such as analyses sets. Invalidated analyses must also be
111   /// removed from pa.
112   virtual bool invalidate(PreservedAnalyses &pa) = 0;
113 };
114 
115 /// A derived analysis model used to hold a specific analysis object.
116 template <typename AnalysisT>
117 struct AnalysisModel : public AnalysisConcept {
118   template <typename... Args>
119   explicit AnalysisModel(Args &&...args)
120       : analysis(std::forward<Args>(args)...) {}
121 
122   /// A hook used to query analyses for invalidation. Removes invalidated
123   /// analyses from pa.
124   bool invalidate(PreservedAnalyses &pa) final {
125     bool result = analysis_impl::isInvalidated(analysis, pa);
126     if (result)
127       pa.unpreserve<AnalysisT>();
128     return result;
129   }
130 
131   /// The actual analysis object.
132   AnalysisT analysis;
133 };
134 
135 /// This class represents a cache of analyses for a single operation. All
136 /// computation, caching, and invalidation of analyses takes place here.
137 class AnalysisMap {
138   /// A mapping between an analysis id and an existing analysis instance.
139   using ConceptMap = llvm::MapVector<TypeID, std::unique_ptr<AnalysisConcept>>;
140 
141   /// Utility to return the name of the given analysis class.
142   template <typename AnalysisT>
143   static StringRef getAnalysisName() {
144     StringRef name = llvm::getTypeName<AnalysisT>();
145     if (!name.consume_front("mlir::"))
146       name.consume_front("(anonymous namespace)::");
147     return name;
148   }
149 
150 public:
151   explicit AnalysisMap(Operation *ir) : ir(ir) {}
152 
153   /// Get an analysis for the current IR unit, computing it if necessary.
154   template <typename AnalysisT>
155   AnalysisT &getAnalysis(PassInstrumentor *pi, AnalysisManager &am) {
156     return getAnalysisImpl<AnalysisT, Operation *>(pi, ir, am);
157   }
158 
159   /// Get an analysis for the current IR unit assuming it's of specific derived
160   /// operation type.
161   template <typename AnalysisT, typename OpT>
162   std::enable_if_t<
163       std::is_constructible<AnalysisT, OpT>::value ||
164           std::is_constructible<AnalysisT, OpT, AnalysisManager &>::value,
165       AnalysisT &>
166   getAnalysis(PassInstrumentor *pi, AnalysisManager &am) {
167     return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir), am);
168   }
169 
170   /// Get a cached analysis instance if one exists, otherwise return null.
171   template <typename AnalysisT>
172   std::optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
173     auto res = analyses.find(TypeID::get<AnalysisT>());
174     if (res == analyses.end())
175       return std::nullopt;
176     return {static_cast<AnalysisModel<AnalysisT> &>(*res->second).analysis};
177   }
178 
179   /// Returns the operation that this analysis map represents.
180   Operation *getOperation() const { return ir; }
181 
182   /// Clear any held analyses.
183   void clear() { analyses.clear(); }
184 
185   /// Invalidate any cached analyses based upon the given set of preserved
186   /// analyses.
187   void invalidate(const PreservedAnalyses &pa) {
188     PreservedAnalyses paCopy(pa);
189     // Remove any analyses that were invalidated.
190     // As we are using MapVector, order of insertion is preserved and
191     // dependencies always go before users, so we need only one iteration.
192     analyses.remove_if(
193         [&](auto &val) { return val.second->invalidate(paCopy); });
194   }
195 
196 private:
197   template <typename AnalysisT, typename OpT>
198   AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op,
199                              AnalysisManager &am) {
200     TypeID id = TypeID::get<AnalysisT>();
201 
202     auto it = analyses.find(id);
203     // If we don't have a cached analysis for this operation, compute it
204     // directly and add it to the cache.
205     if (analyses.end() == it) {
206       if (pi)
207         pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
208 
209       bool wasInserted;
210       std::tie(it, wasInserted) =
211           analyses.insert({id, constructAnalysis<AnalysisT>(am, op)});
212       assert(wasInserted);
213 
214       if (pi)
215         pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
216     }
217     return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
218   }
219 
220   /// Construct analysis using two arguments constructor (OpT, AnalysisManager)
221   template <typename AnalysisT, typename OpT,
222             std::enable_if_t<std::is_constructible<
223                 AnalysisT, OpT, AnalysisManager &>::value> * = nullptr>
224   static auto constructAnalysis(AnalysisManager &am, OpT op) {
225     return std::make_unique<AnalysisModel<AnalysisT>>(op, am);
226   }
227 
228   /// Construct analysis using single argument constructor (OpT)
229   template <typename AnalysisT, typename OpT,
230             std::enable_if_t<!std::is_constructible<
231                 AnalysisT, OpT, AnalysisManager &>::value> * = nullptr>
232   static auto constructAnalysis(AnalysisManager &, OpT op) {
233     return std::make_unique<AnalysisModel<AnalysisT>>(op);
234   }
235 
236   Operation *ir;
237   ConceptMap analyses;
238 };
239 
240 /// An analysis map that contains a map for the current operation, and a set of
241 /// maps for any child operations.
242 struct NestedAnalysisMap {
243   NestedAnalysisMap(Operation *op, PassInstrumentor *instrumentor)
244       : analyses(op), parentOrInstrumentor(instrumentor) {}
245   NestedAnalysisMap(Operation *op, NestedAnalysisMap *parent)
246       : analyses(op), parentOrInstrumentor(parent) {}
247 
248   /// Get the operation for this analysis map.
249   Operation *getOperation() const { return analyses.getOperation(); }
250 
251   /// Invalidate any non preserved analyses.
252   void invalidate(const PreservedAnalyses &pa);
253 
254   /// Returns the parent analysis map for this analysis map, or null if this is
255   /// the top-level map.
256   const NestedAnalysisMap *getParent() const {
257     return llvm::dyn_cast_if_present<NestedAnalysisMap *>(parentOrInstrumentor);
258   }
259 
260   /// Returns a pass instrumentation object for the current operation. This
261   /// value may be null.
262   PassInstrumentor *getPassInstrumentor() const {
263     if (auto *parent = getParent())
264       return parent->getPassInstrumentor();
265     return cast<PassInstrumentor *>(parentOrInstrumentor);
266   }
267 
268   /// The cached analyses for nested operations.
269   DenseMap<Operation *, std::unique_ptr<NestedAnalysisMap>> childAnalyses;
270 
271   /// The analyses for the owning operation.
272   detail::AnalysisMap analyses;
273 
274   /// This value has three possible states:
275   /// NestedAnalysisMap*: A pointer to the parent analysis map.
276   /// PassInstrumentor*: This analysis map is the top-level map, and this
277   ///                    pointer is the optional pass instrumentor for the
278   ///                    current compilation.
279   /// nullptr: This analysis map is the top-level map, and there is nop pass
280   ///          instrumentor.
281   PointerUnion<NestedAnalysisMap *, PassInstrumentor *> parentOrInstrumentor;
282 };
283 } // namespace detail
284 
285 //===----------------------------------------------------------------------===//
286 // Analysis Management
287 //===----------------------------------------------------------------------===//
288 class ModuleAnalysisManager;
289 
290 /// This class represents an analysis manager for a particular operation
291 /// instance. It is used to manage and cache analyses on the operation as well
292 /// as those for child operations, via nested AnalysisManager instances
293 /// accessible via 'slice'. This class is intended to be passed around by value,
294 /// and cannot be constructed directly.
295 class AnalysisManager {
296   using ParentPointerT =
297       PointerUnion<const ModuleAnalysisManager *, const AnalysisManager *>;
298 
299 public:
300   using PreservedAnalyses = detail::PreservedAnalyses;
301 
302   /// Query for a cached analysis on the given parent operation. The analysis
303   /// may not exist and if it does it may be out-of-date.
304   template <typename AnalysisT>
305   std::optional<std::reference_wrapper<AnalysisT>>
306   getCachedParentAnalysis(Operation *parentOp) const {
307     const detail::NestedAnalysisMap *curParent = impl;
308     while (auto *parentAM = curParent->getParent()) {
309       if (parentAM->getOperation() == parentOp)
310         return parentAM->analyses.getCachedAnalysis<AnalysisT>();
311       curParent = parentAM;
312     }
313     return std::nullopt;
314   }
315 
316   /// Query for the given analysis for the current operation.
317   template <typename AnalysisT>
318   AnalysisT &getAnalysis() {
319     return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor(), *this);
320   }
321 
322   /// Query for the given analysis for the current operation of a specific
323   /// derived operation type.
324   template <typename AnalysisT, typename OpT>
325   AnalysisT &getAnalysis() {
326     return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor(),
327                                                       *this);
328   }
329 
330   /// Query for a cached entry of the given analysis on the current operation.
331   template <typename AnalysisT>
332   std::optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
333     return impl->analyses.getCachedAnalysis<AnalysisT>();
334   }
335 
336   /// Query for an analysis of a child operation, constructing it if necessary.
337   template <typename AnalysisT>
338   AnalysisT &getChildAnalysis(Operation *op) {
339     return nest(op).template getAnalysis<AnalysisT>();
340   }
341 
342   /// Query for an analysis of a child operation of a specific derived operation
343   /// type, constructing it if necessary.
344   template <typename AnalysisT, typename OpT>
345   AnalysisT &getChildAnalysis(OpT child) {
346     return nest(child).template getAnalysis<AnalysisT, OpT>();
347   }
348 
349   /// Query for a cached analysis of a child operation, or return null.
350   template <typename AnalysisT>
351   std::optional<std::reference_wrapper<AnalysisT>>
352   getCachedChildAnalysis(Operation *op) const {
353     assert(op->getParentOp() == impl->getOperation());
354     auto it = impl->childAnalyses.find(op);
355     if (it == impl->childAnalyses.end())
356       return std::nullopt;
357     return it->second->analyses.getCachedAnalysis<AnalysisT>();
358   }
359 
360   /// Get an analysis manager for the given operation, which must be a proper
361   /// descendant of the current operation represented by this analysis manager.
362   AnalysisManager nest(Operation *op);
363 
364   /// Invalidate any non preserved analyses,
365   void invalidate(const PreservedAnalyses &pa) { impl->invalidate(pa); }
366 
367   /// Clear any held analyses.
368   void clear() {
369     impl->analyses.clear();
370     impl->childAnalyses.clear();
371   }
372 
373   /// Returns a pass instrumentation object for the current operation. This
374   /// value may be null.
375   PassInstrumentor *getPassInstrumentor() const {
376     return impl->getPassInstrumentor();
377   }
378 
379 private:
380   AnalysisManager(detail::NestedAnalysisMap *impl) : impl(impl) {}
381 
382   /// Get an analysis manager for the given immediately nested child operation.
383   AnalysisManager nestImmediate(Operation *op);
384 
385   /// A reference to the impl analysis map within the parent analysis manager.
386   detail::NestedAnalysisMap *impl;
387 
388   /// Allow access to the constructor.
389   friend class ModuleAnalysisManager;
390 };
391 
392 /// An analysis manager class specifically for the top-level operation. This
393 /// class contains the memory allocations for all nested analysis managers, and
394 /// provides an anchor point. This is necessary because AnalysisManager is
395 /// designed to be a thin wrapper around an existing analysis map instance.
396 class ModuleAnalysisManager {
397 public:
398   ModuleAnalysisManager(Operation *op, PassInstrumentor *passInstrumentor)
399       : analyses(op, passInstrumentor) {}
400   ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
401   ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
402 
403   /// Returns an analysis manager for the current top-level module.
404   operator AnalysisManager() { return AnalysisManager(&analyses); }
405 
406 private:
407   /// The analyses for the owning module.
408   detail::NestedAnalysisMap analyses;
409 };
410 
411 } // namespace mlir
412 
413 #endif // MLIR_PASS_ANALYSISMANAGER_H
414