1 //===- OneShotAnalysis.h - One-Shot (Single Pass) Analysis ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H 10 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H 11 12 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 13 #include "llvm/ADT/EquivalenceClasses.h" 14 #include <string> 15 16 namespace mlir { 17 class DominanceInfo; 18 19 namespace bufferization { 20 21 struct OneShotBufferizationOptions; 22 struct BufferizationStatistics; 23 class OneShotAnalysisState; 24 25 /// Options for analysis-enabled bufferization. 26 struct OneShotBufferizationOptions : public BufferizationOptions { 27 enum class AnalysisHeuristic { 28 BottomUp, 29 TopDown, 30 BottomUpFromTerminators, 31 Fuzzer 32 }; 33 34 OneShotBufferizationOptions() = default; 35 36 /// Specifies whether returning newly allocated memrefs from loops should be 37 /// allowed. Otherwise, a pass failure is triggered. 38 bool allowReturnAllocsFromLoops = false; 39 40 /// Specifies whether the tensor IR should be annotated with alias sets. 41 bool dumpAliasSets = false; 42 43 /// The heuristic controls the order in which ops are traversed during the 44 /// analysis. 45 AnalysisHeuristic analysisHeuristic = AnalysisHeuristic::BottomUp; 46 47 /// Specify the functions that should not be analyzed. copyBeforeWrite will be 48 /// set to true when bufferizing them. 49 llvm::ArrayRef<std::string> noAnalysisFuncFilter; 50 51 /// Seed for the analysis fuzzer. Used only if the heuristic is set to 52 /// `AnalysisHeuristic::Fuzzer`. The fuzzer should be used only with 53 /// `testAnalysisOnly = true`. 54 unsigned analysisFuzzerSeed = 0; 55 }; 56 57 /// State for analysis-enabled bufferization. This class keeps track of alias 58 /// sets, equivalence sets, in-place OpOperands and other things. 59 /// 60 /// Note: Modifying the IR generally invalidates the result of the analysis. 61 /// Adding new operations is safe if they are analyzed subsequently. 62 class OneShotAnalysisState : public AnalysisState { 63 public: 64 OneShotAnalysisState(Operation *op, 65 const OneShotBufferizationOptions &options); 66 67 OneShotAnalysisState(const OneShotAnalysisState &) = delete; 68 69 ~OneShotAnalysisState() override = default; 70 71 static bool classof(const AnalysisState *base) { 72 return base->getType() == TypeID::get<OneShotAnalysisState>(); 73 } 74 75 /// Return a reference to the BufferizationOptions. 76 const OneShotBufferizationOptions &getOptions() const { 77 return static_cast<const OneShotBufferizationOptions &>( 78 AnalysisState::getOptions()); 79 } 80 81 /// Analyze the given op and its nested ops. 82 LogicalResult analyzeOp(Operation *op, const DominanceInfo &domInfo); 83 84 /// Analyze a single op (without nested ops). 85 LogicalResult analyzeSingleOp(Operation *op, const DominanceInfo &domInfo); 86 87 /// Apply `fun` to all the members of the equivalence class of `v`. 88 void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const; 89 90 /// Apply `fun` to all aliases of `v`. 91 void applyOnAliases(Value v, function_ref<void(Value)> fun) const; 92 93 /// Return true if `v1` and `v2` bufferize to equivalent buffers. 94 bool areEquivalentBufferizedValues(Value v1, Value v2) const override; 95 96 /// Return true if `v1` and `v2` may bufferize to aliasing buffers. 97 bool areAliasingBufferizedValues(Value v1, Value v2) const override; 98 99 /// Mark the given OpOperand as in-place and merge the results' and operand's 100 /// aliasing sets. 101 void bufferizeInPlace(OpOperand &operand); 102 103 /// Mark the given OpOperand as out-of-place. 104 void bufferizeOutOfPlace(OpOperand &operand); 105 106 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the 107 /// beginning the alias and equivalence sets only contain `v` itself. 108 void createAliasInfoEntry(Value v); 109 110 /// Find all tensor values in the given operation that have undefined contents 111 /// and store them in `undefinedTensorUses`. 112 void gatherUndefinedTensorUses(Operation *op); 113 114 int64_t getStatNumTensorOutOfPlace() const { return statNumTensorOutOfPlace; } 115 int64_t getStatNumTensorInPlace() const { return statNumTensorInPlace; } 116 117 /// Return `true` if the given tensor has undefined contents. 118 bool hasUndefinedContents(OpOperand *opOperand) const override; 119 120 /// Return `true` if the given OpResult has been decided to bufferize inplace. 121 bool isInPlace(OpOperand &opOperand) const override; 122 123 /// Return true if the buffer of the given tensor value is written to. Must 124 /// not be called for values inside not yet analyzed functions. 125 bool isValueWritten(Value value) const; 126 127 /// Return true if the buffer of the given tensor value is writable. 128 bool isWritable(Value value) const; 129 130 /// Find the definitions of the given operand's value or 131 /// retrieve them from the cache. 132 const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand); 133 134 /// Reset cached data structures. 135 void resetCache() override; 136 137 /// Union the alias sets of `v1` and `v2`. 138 void unionAliasSets(Value v1, Value v2); 139 140 /// Union the equivalence classes of `v1` and `v2`. 141 void unionEquivalenceClasses(Value v1, Value v2); 142 143 /// Base class for OneShotAnalysisState extensions that allow 144 /// OneShotAnalysisState to contain user-specified information in the state 145 /// object. Clients are expected to derive this class, add the desired fields, 146 /// and make the derived class compatible with the MLIR TypeID mechanism. 147 /// 148 /// ```mlir 149 /// class MyExtension final : public OneShotAnalysisState::Extension { 150 /// public: 151 /// MyExtension(OneShotAnalysisState &state, int myData) 152 /// : Extension(state) {...} 153 /// private: 154 /// int mySupplementaryData; 155 /// }; 156 /// ``` 157 /// 158 /// Instances of this and derived classes are not expected to be created by 159 /// the user, instead they are directly constructed within a 160 /// OneShotAnalysisState. A OneShotAnalysisState can only contain one 161 /// extension with the given TypeID. Extensions can be obtained from a 162 /// OneShotAnalysisState instance. 163 /// 164 /// ```mlir 165 /// state.addExtension<MyExtension>(/*myData=*/42); 166 /// MyExtension *ext = state.getExtension<MyExtension>(); 167 /// ext->doSomething(); 168 /// ``` 169 class Extension { 170 // Allow OneShotAnalysisState to allocate Extensions. 171 friend class OneShotAnalysisState; 172 173 public: 174 /// Base virtual destructor. 175 // Out-of-line definition ensures symbols are emitted in a single object 176 // file. 177 virtual ~Extension(); 178 179 protected: 180 /// Constructs an extension of the given state object. 181 Extension(OneShotAnalysisState &state) : state(state) {} 182 183 /// Provides read-only access to the parent OneShotAnalysisState object. 184 const OneShotAnalysisState &getAnalysisState() const { return state; } 185 186 private: 187 /// Back-reference to the state that is being extended. 188 OneShotAnalysisState &state; 189 }; 190 191 /// Adds a new Extension of the type specified as template parameter, 192 /// constructing it with the arguments provided. The extension is owned by the 193 /// OneShotAnalysisState. It is expected that the state does not already have 194 /// an extension of the same type. Extension constructors are expected to take 195 /// a reference to OneShotAnalysisState as first argument, automatically 196 /// supplied by this call. 197 template <typename Ty, typename... Args> 198 Ty &addExtension(Args &&...args) { 199 static_assert( 200 std::is_base_of<Extension, Ty>::value, 201 "only a class derived from OneShotAnalysisState::Extension is allowed"); 202 auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...); 203 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr)); 204 assert(result.second && "extension already added"); 205 return *static_cast<Ty *>(result.first->second.get()); 206 } 207 208 /// Returns the extension of the specified type. 209 template <typename Ty> 210 Ty *getExtension() { 211 static_assert( 212 std::is_base_of<Extension, Ty>::value, 213 "only a class derived from OneShotAnalysisState::Extension is allowed"); 214 auto iter = extensions.find(TypeID::get<Ty>()); 215 if (iter == extensions.end()) 216 return nullptr; 217 return static_cast<Ty *>(iter->second.get()); 218 } 219 220 /// Returns the extension of the specified type. 221 template <typename Ty> 222 const Ty *getExtension() const { 223 return const_cast<OneShotAnalysisState *>(this)->getExtension<Ty>(); 224 } 225 226 private: 227 /// llvm::EquivalenceClasses wants comparable elements. This comparator uses 228 /// pointer comparison on the defining op. This is a poor man's comparison 229 /// but it's not like UnionFind needs ordering anyway. 230 struct ValueComparator { 231 bool operator()(const Value &lhs, const Value &rhs) const { 232 return lhs.getImpl() < rhs.getImpl(); 233 } 234 }; 235 236 using EquivalenceClassRangeType = llvm::iterator_range< 237 llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>; 238 /// Check that aliasInfo for `v` exists and return a reference to it. 239 EquivalenceClassRangeType getAliases(Value v) const; 240 241 /// Cache definitions of tensor values. 242 DenseMap<Value, SetVector<Value>> cachedDefinitions; 243 244 /// Set of all OpResults that were decided to bufferize in-place. 245 llvm::DenseSet<OpOperand *> inplaceBufferized; 246 247 /// Auxiliary structure to store all the values a given value may alias with. 248 /// Alias information is "may be" conservative: In the presence of branches, a 249 /// value may alias with one of multiple other values. The concrete aliasing 250 /// value may not even be known at compile time. All such values are 251 /// considered to be aliases. 252 llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo; 253 254 /// Auxiliary structure to store all the equivalent buffer classes. Equivalent 255 /// buffer information is "must be" conservative: Only if two values are 256 /// guaranteed to be equivalent at runtime, they said to be equivalent. It is 257 /// possible that, in the presence of branches, it cannot be determined 258 /// statically if two values are equivalent. In that case, the values are 259 /// considered to be not equivalent. 260 llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo; 261 262 // Bufferization statistics. 263 int64_t statNumTensorOutOfPlace = 0; 264 int64_t statNumTensorInPlace = 0; 265 266 /// A set of uses of tensors that have undefined contents. 267 DenseSet<OpOperand *> undefinedTensorUses; 268 269 /// Extensions attached to the state, identified by the TypeID of their type. 270 /// Only one extension of any given type is allowed. 271 DenseMap<TypeID, std::unique_ptr<Extension>> extensions; 272 }; 273 274 /// Analyze `op` and its nested ops. Bufferization decisions are stored in 275 /// `state`. 276 LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, 277 BufferizationStatistics *statistics = nullptr); 278 279 /// Run One-Shot Bufferize on the given op: Analysis + Bufferization 280 LogicalResult 281 runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, 282 BufferizationStatistics *statistics = nullptr); 283 284 } // namespace bufferization 285 } // namespace mlir 286 287 MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState) 288 289 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H 290