// xref: /llvm-project/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h (revision d9111f19d2ea53d8ce105b3d09425394ccf37969)
//===- OneShotAnalysis.h - One-Shot (Single Pass) Analysis ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H
10 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H
11 
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
15 
16 namespace mlir {
// Only needed by const reference (OneShotAnalysisState::analyzeOp and
// analyzeSingleOp), so a forward declaration suffices.
class DominanceInfo;
18 
19 namespace bufferization {
20 
// Only used through pointers in the free-function declarations of this header
// (analyzeOp, runOneShotBufferize). OneShotBufferizationOptions and
// OneShotAnalysisState are defined in full below before any use, so they need
// no forward declaration.
struct BufferizationStatistics;
24 
25 /// Options for analysis-enabled bufferization.
26 struct OneShotBufferizationOptions : public BufferizationOptions {
27   enum class AnalysisHeuristic {
28     BottomUp,
29     TopDown,
30     BottomUpFromTerminators,
31     Fuzzer
32   };
33 
34   OneShotBufferizationOptions() = default;
35 
36   /// Specifies whether returning newly allocated memrefs from loops should be
37   /// allowed.  Otherwise, a pass failure is triggered.
38   bool allowReturnAllocsFromLoops = false;
39 
40   /// Specifies whether the tensor IR should be annotated with alias sets.
41   bool dumpAliasSets = false;
42 
43   /// The heuristic controls the order in which ops are traversed during the
44   /// analysis.
45   AnalysisHeuristic analysisHeuristic = AnalysisHeuristic::BottomUp;
46 
47   /// Specify the functions that should not be analyzed. copyBeforeWrite will be
48   /// set to true when bufferizing them.
49   llvm::ArrayRef<std::string> noAnalysisFuncFilter;
50 
51   /// Seed for the analysis fuzzer. Used only if the heuristic is set to
52   /// `AnalysisHeuristic::Fuzzer`. The fuzzer should be used only with
53   /// `testAnalysisOnly = true`.
54   unsigned analysisFuzzerSeed = 0;
55 };
56 
57 /// State for analysis-enabled bufferization. This class keeps track of alias
58 /// sets, equivalence sets, in-place OpOperands and other things.
59 ///
60 /// Note: Modifying the IR generally invalidates the result of the analysis.
61 /// Adding new operations is safe if they are analyzed subsequently.
62 class OneShotAnalysisState : public AnalysisState {
63 public:
64   OneShotAnalysisState(Operation *op,
65                        const OneShotBufferizationOptions &options);
66 
67   OneShotAnalysisState(const OneShotAnalysisState &) = delete;
68 
69   ~OneShotAnalysisState() override = default;
70 
71   static bool classof(const AnalysisState *base) {
72     return base->getType() == TypeID::get<OneShotAnalysisState>();
73   }
74 
75   /// Return a reference to the BufferizationOptions.
76   const OneShotBufferizationOptions &getOptions() const {
77     return static_cast<const OneShotBufferizationOptions &>(
78         AnalysisState::getOptions());
79   }
80 
81   /// Analyze the given op and its nested ops.
82   LogicalResult analyzeOp(Operation *op, const DominanceInfo &domInfo);
83 
84   /// Analyze a single op (without nested ops).
85   LogicalResult analyzeSingleOp(Operation *op, const DominanceInfo &domInfo);
86 
87   /// Apply `fun` to all the members of the equivalence class of `v`.
88   void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
89 
90   /// Apply `fun` to all aliases of `v`.
91   void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
92 
93   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
94   bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
95 
96   /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
97   bool areAliasingBufferizedValues(Value v1, Value v2) const override;
98 
99   /// Mark the given OpOperand as in-place and merge the results' and operand's
100   /// aliasing sets.
101   void bufferizeInPlace(OpOperand &operand);
102 
103   /// Mark the given OpOperand as out-of-place.
104   void bufferizeOutOfPlace(OpOperand &operand);
105 
106   /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
107   /// beginning the alias and equivalence sets only contain `v` itself.
108   void createAliasInfoEntry(Value v);
109 
110   /// Find all tensor values in the given operation that have undefined contents
111   /// and store them in `undefinedTensorUses`.
112   void gatherUndefinedTensorUses(Operation *op);
113 
114   int64_t getStatNumTensorOutOfPlace() const { return statNumTensorOutOfPlace; }
115   int64_t getStatNumTensorInPlace() const { return statNumTensorInPlace; }
116 
117   /// Return `true` if the given tensor has undefined contents.
118   bool hasUndefinedContents(OpOperand *opOperand) const override;
119 
120   /// Return `true` if the given OpResult has been decided to bufferize inplace.
121   bool isInPlace(OpOperand &opOperand) const override;
122 
123   /// Return true if the buffer of the given tensor value is written to. Must
124   /// not be called for values inside not yet analyzed functions.
125   bool isValueWritten(Value value) const;
126 
127   /// Return true if the buffer of the given tensor value is writable.
128   bool isWritable(Value value) const;
129 
130   /// Find the definitions of the given operand's value or
131   /// retrieve them from the cache.
132   const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);
133 
134   /// Reset cached data structures.
135   void resetCache() override;
136 
137   /// Union the alias sets of `v1` and `v2`.
138   void unionAliasSets(Value v1, Value v2);
139 
140   /// Union the equivalence classes of `v1` and `v2`.
141   void unionEquivalenceClasses(Value v1, Value v2);
142 
143   /// Base class for OneShotAnalysisState extensions that allow
144   /// OneShotAnalysisState to contain user-specified information in the state
145   /// object. Clients are expected to derive this class, add the desired fields,
146   /// and make the derived class compatible with the MLIR TypeID mechanism.
147   ///
148   /// ```mlir
149   /// class MyExtension final : public OneShotAnalysisState::Extension {
150   /// public:
151   ///   MyExtension(OneShotAnalysisState &state, int myData)
152   ///       : Extension(state) {...}
153   /// private:
154   ///   int mySupplementaryData;
155   /// };
156   /// ```
157   ///
158   /// Instances of this and derived classes are not expected to be created by
159   /// the user, instead they are directly constructed within a
160   /// OneShotAnalysisState. A OneShotAnalysisState can only contain one
161   /// extension with the given TypeID. Extensions can be obtained from a
162   /// OneShotAnalysisState instance.
163   ///
164   /// ```mlir
165   /// state.addExtension<MyExtension>(/*myData=*/42);
166   /// MyExtension *ext = state.getExtension<MyExtension>();
167   /// ext->doSomething();
168   /// ```
169   class Extension {
170     // Allow OneShotAnalysisState to allocate Extensions.
171     friend class OneShotAnalysisState;
172 
173   public:
174     /// Base virtual destructor.
175     // Out-of-line definition ensures symbols are emitted in a single object
176     // file.
177     virtual ~Extension();
178 
179   protected:
180     /// Constructs an extension of the given state object.
181     Extension(OneShotAnalysisState &state) : state(state) {}
182 
183     /// Provides read-only access to the parent OneShotAnalysisState object.
184     const OneShotAnalysisState &getAnalysisState() const { return state; }
185 
186   private:
187     /// Back-reference to the state that is being extended.
188     OneShotAnalysisState &state;
189   };
190 
191   /// Adds a new Extension of the type specified as template parameter,
192   /// constructing it with the arguments provided. The extension is owned by the
193   /// OneShotAnalysisState. It is expected that the state does not already have
194   /// an extension of the same type. Extension constructors are expected to take
195   /// a reference to OneShotAnalysisState as first argument, automatically
196   /// supplied by this call.
197   template <typename Ty, typename... Args>
198   Ty &addExtension(Args &&...args) {
199     static_assert(
200         std::is_base_of<Extension, Ty>::value,
201         "only a class derived from OneShotAnalysisState::Extension is allowed");
202     auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
203     auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
204     assert(result.second && "extension already added");
205     return *static_cast<Ty *>(result.first->second.get());
206   }
207 
208   /// Returns the extension of the specified type.
209   template <typename Ty>
210   Ty *getExtension() {
211     static_assert(
212         std::is_base_of<Extension, Ty>::value,
213         "only a class derived from OneShotAnalysisState::Extension is allowed");
214     auto iter = extensions.find(TypeID::get<Ty>());
215     if (iter == extensions.end())
216       return nullptr;
217     return static_cast<Ty *>(iter->second.get());
218   }
219 
220   /// Returns the extension of the specified type.
221   template <typename Ty>
222   const Ty *getExtension() const {
223     return const_cast<OneShotAnalysisState *>(this)->getExtension<Ty>();
224   }
225 
226 private:
227   /// llvm::EquivalenceClasses wants comparable elements. This comparator uses
228   /// pointer comparison on the defining op. This is a poor man's comparison
229   /// but it's not like UnionFind needs ordering anyway.
230   struct ValueComparator {
231     bool operator()(const Value &lhs, const Value &rhs) const {
232       return lhs.getImpl() < rhs.getImpl();
233     }
234   };
235 
236   using EquivalenceClassRangeType = llvm::iterator_range<
237       llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
238   /// Check that aliasInfo for `v` exists and return a reference to it.
239   EquivalenceClassRangeType getAliases(Value v) const;
240 
241   /// Cache definitions of tensor values.
242   DenseMap<Value, SetVector<Value>> cachedDefinitions;
243 
244   /// Set of all OpResults that were decided to bufferize in-place.
245   llvm::DenseSet<OpOperand *> inplaceBufferized;
246 
247   /// Auxiliary structure to store all the values a given value may alias with.
248   /// Alias information is "may be" conservative: In the presence of branches, a
249   /// value may alias with one of multiple other values. The concrete aliasing
250   /// value may not even be known at compile time. All such values are
251   /// considered to be aliases.
252   llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
253 
254   /// Auxiliary structure to store all the equivalent buffer classes. Equivalent
255   /// buffer information is "must be" conservative: Only if two values are
256   /// guaranteed to be equivalent at runtime, they said to be equivalent. It is
257   /// possible that, in the presence of branches, it cannot be determined
258   /// statically if two values are equivalent. In that case, the values are
259   /// considered to be not equivalent.
260   llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
261 
262   // Bufferization statistics.
263   int64_t statNumTensorOutOfPlace = 0;
264   int64_t statNumTensorInPlace = 0;
265 
266   /// A set of uses of tensors that have undefined contents.
267   DenseSet<OpOperand *> undefinedTensorUses;
268 
269   /// Extensions attached to the state, identified by the TypeID of their type.
270   /// Only one extension of any given type is allowed.
271   DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
272 };
273 
274 /// Analyze `op` and its nested ops. Bufferization decisions are stored in
275 /// `state`.
276 LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
277                         BufferizationStatistics *statistics = nullptr);
278 
279 /// Run One-Shot Bufferize on the given op: Analysis + Bufferization
280 LogicalResult
281 runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
282                     BufferizationStatistics *statistics = nullptr);
283 
284 } // namespace bufferization
285 } // namespace mlir
286 
287 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
288 
289 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H
290