# Writing DataFlow Analyses in MLIR

Writing dataflow analyses in MLIR, or in any compiler, can often seem daunting
and complex. A dataflow analysis generally involves propagating
information about the IR across various different types of control flow
constructs, of which MLIR has many (Block-based branches, Region-based branches,
CallGraph, etc.), and it isn't always clear how best to go about performing the
propagation. To help with writing these types of analyses in MLIR, this document
details several utilities that simplify the process and make it a bit more
approachable.

## Forward Dataflow Analysis

One type of dataflow analysis is a forward propagation analysis. This type of
analysis, as the name may suggest, propagates information forward (e.g. from
definitions to uses). To provide a bit of concrete context, let's go over
writing a simple forward dataflow analysis in MLIR. Let's say for this analysis
that we want to propagate information about a special "metadata" dictionary
attribute. The contents of this attribute are simply a set of metadata that
describe a specific value, e.g. `metadata = { likes_pizza = true }`. We will
collect the `metadata` for operations in the IR and propagate them about.

### Lattices

Before going into how one might set up the analysis itself, it is important to
first introduce the concept of a `Lattice` and how we will use it for the
analysis. A lattice represents all of the possible values or results of the
analysis for a given value. A lattice element holds the set of information
computed by the analysis for a given value, and is what gets propagated across
the IR. For our analysis, this would correspond to the `metadata` dictionary
attribute.

Regardless of the value held within, every type of lattice contains two special
element states:

* `uninitialized`

  - The element has not been initialized.

* `top`/`overdefined`/`unknown`

  - The element encompasses every possible value.
  - This is a very conservative state, and essentially means "I can't make
    any assumptions about the value, it could be anything".

These two states are important when merging information as part of the
analysis, which this document refers to as `join`ing. Lattice elements
are `join`ed whenever there are two different source points, such as an argument
to a block with multiple predecessors. One important note about the `join`
operation is that it is required to be monotonic (see the `join` method in the
example below for more information). This ensures that `join`ing elements is
consistent. The two special states mentioned above have unique properties during
a `join`:

* `uninitialized`

  - If one of the elements is `uninitialized`, the other element is used.
  - `uninitialized` in the context of a `join` essentially means "take the
    other thing".

* `top`/`overdefined`/`unknown`

  - If one of the elements being joined is `overdefined`, the result is
    `overdefined`.

For our analysis in MLIR, we will need to define a class representing the value
held by an element of the lattice used by our dataflow analysis:

```c++
/// The value of our lattice represents the inner structure of a DictionaryAttr,
/// for the `metadata`.
struct MetadataLatticeValue {
  MetadataLatticeValue() = default;
  /// Compute a lattice value from the provided dictionary.
  MetadataLatticeValue(DictionaryAttr attr)
      : metadata(attr.begin(), attr.end()) {}

  /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown`
  /// state, for our value type. The resultant state should not assume any
  /// information about the state of the IR.
  static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) {
    // The `top`/`overdefined`/`unknown` state is when we know nothing about any
    // metadata, i.e. an empty dictionary.
    return MetadataLatticeValue();
  }
  /// Return a pessimistic value state for our value type using only information
  /// about the state of the provided IR. This is similar to the above method,
  /// but may produce a slightly more refined result. This is okay, as the
  /// information is already encoded as fact in the IR.
  static MetadataLatticeValue getPessimisticValueState(Value value) {
    // Check to see if the parent operation has metadata.
    if (Operation *parentOp = value.getDefiningOp()) {
      if (auto metadata = parentOp->getAttrOfType<DictionaryAttr>("metadata"))
        return MetadataLatticeValue(metadata);

      // If no metadata is present, fallback to the
      // `top`/`overdefined`/`unknown` state.
    }
    return MetadataLatticeValue();
  }

  /// This method conservatively joins the information held by `lhs` and `rhs`
  /// into a new value. This method is required to be monotonic. `monotonicity`
  /// is implied by the satisfaction of the following axioms:
  ///   * idempotence:   join(x,x) == x
  ///   * commutativity: join(x,y) == join(y,x)
  ///   * associativity: join(x,join(y,z)) == join(join(x,y),z)
  ///
  /// When the above axioms are satisfied, we achieve `monotonicity`:
  ///   * monotonicity: join(x, join(x,y)) == join(x,y)
  static MetadataLatticeValue join(const MetadataLatticeValue &lhs,
                                   const MetadataLatticeValue &rhs) {
    // To join `lhs` and `rhs` we will define a simple policy, which is that we
    // only keep information that is the same. This means that we only keep
    // facts that are true in both.
    MetadataLatticeValue result;
    for (const auto &lhsIt : lhs.metadata) {
      // As noted above, we only merge if the values are the same.
      auto it = rhs.metadata.find(lhsIt.first);
      if (it == rhs.metadata.end() || it->second != lhsIt.second)
        continue;
      result.metadata.insert(lhsIt);
    }
    return result;
  }

  /// A simple comparator that checks to see if this value is equal to the one
  /// provided.
  bool operator==(const MetadataLatticeValue &rhs) const {
    if (metadata.size() != rhs.metadata.size())
      return false;
    // Check that `rhs` contains the same metadata.
    for (const auto &it : metadata) {
      auto rhsIt = rhs.metadata.find(it.first);
      if (rhsIt == rhs.metadata.end() || it.second != rhsIt->second)
        return false;
    }
    return true;
  }

  /// Our value represents the combined metadata, which is originally a
  /// DictionaryAttr, so we use a map.
  DenseMap<StringAttr, Attribute> metadata;
};
```

One interesting thing to note above is that we don't have an explicit method for
the `uninitialized` state. This state is handled by the `LatticeElement` class,
which manages a lattice value for a given IR entity. A quick overview of this
class, and the API that will be interesting to us while writing our analysis, is
shown below:

```c++
/// This class represents a lattice element holding a specific value of type
/// `ValueT`.
template <typename ValueT>
class LatticeElement ... {
public:
  /// Return the value held by this element. This requires that a value is
  /// known, i.e. not `uninitialized`.
  ValueT &getValue();
  const ValueT &getValue() const;

  /// Join the information contained in the 'rhs' element into this
  /// element. Returns if the state of the current element changed.
  ChangeResult join(const LatticeElement<ValueT> &rhs);

  /// Join the information contained in the 'rhs' value into this
  /// lattice. Returns if the state of the current lattice changed.
  ChangeResult join(const ValueT &rhs);

  /// Mark the lattice element as having reached a pessimistic fixpoint. This
  /// means that the lattice may potentially have conflicting value states, and
  /// only the conservatively known value state should be relied on.
  ChangeResult markPessimisticFixPoint();
};
```

With our lattice defined, we can now define the driver that will compute and
propagate our lattice across the IR.

### ForwardDataFlowAnalysis Driver

The `ForwardDataFlowAnalysis` class represents the driver of the dataflow
analysis, and performs all of the related analysis computation. When defining
our analysis, we will inherit from this class and implement some of its hooks.
Before that, let's look at a quick overview of this class and some of the
important API for our analysis:

```c++
/// This class represents the main driver of the forward dataflow analysis. It
/// takes as a template parameter the value type of lattice being computed.
template <typename ValueT>
class ForwardDataFlowAnalysis : ... {
public:
  ForwardDataFlowAnalysis(MLIRContext *context);

  /// Compute the analysis on operations rooted under the given top-level
  /// operation. Note that the top-level operation is not visited.
  void run(Operation *topLevelOp);

  /// Return the lattice element attached to the given value. If a lattice has
  /// not been added for the given value, a new 'uninitialized' value is
  /// inserted and returned.
  LatticeElement<ValueT> &getLatticeElement(Value value);

  /// Return the lattice element attached to the given value, or nullptr if no
  /// lattice element for the value has yet been created.
  LatticeElement<ValueT> *lookupLatticeElement(Value value);

  /// Mark all of the lattice elements for the given range of Values as having
  /// reached a pessimistic fixpoint.
  ChangeResult markAllPessimisticFixPoint(ValueRange values);

protected:
  /// Visit the given operation, and join any necessary analysis state
  /// into the lattice elements for the results and block arguments owned by
  /// this operation using the provided set of operand lattice elements
  /// (all pointer values are guaranteed to be non-null). Returns if any result
  /// or block argument value lattice elements changed during the visit. The
  /// lattice element for a result or block argument value can be obtained, and
  /// join'ed into, by using `getLatticeElement`.
  virtual ChangeResult visitOperation(
      Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) = 0;
};
```

NOTE: Some API has been elided for our example. The `ForwardDataFlowAnalysis`
class contains various other hooks that allow for injecting custom behavior when
applicable.

The main API that we are responsible for defining is the `visitOperation`
method. This method is responsible for computing new lattice elements for the
results and block arguments owned by the given operation. This is where we will
inject the lattice element computation logic, also known as the transfer
function for the operation, that is specific to our analysis. A simple
implementation for our example is shown below:

```c++
class MetadataAnalysis : public ForwardDataFlowAnalysis<MetadataLatticeValue> {
public:
  using ForwardDataFlowAnalysis<MetadataLatticeValue>::ForwardDataFlowAnalysis;

  ChangeResult visitOperation(
      Operation *op,
      ArrayRef<LatticeElement<MetadataLatticeValue> *> operands) override {
    DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");

    // If we have no metadata for this operation, we will conservatively mark
    // all of the results as having reached a pessimistic fixpoint.
    if (!metadata)
      return markAllPessimisticFixPoint(op->getResults());

    // Otherwise, we will compute a lattice value for the metadata and join it
    // into the current lattice element for all of our results.
    MetadataLatticeValue latticeValue(metadata);
    ChangeResult result = ChangeResult::NoChange;
    for (Value value : op->getResults()) {
      // We grab the lattice element for `value` via `getLatticeElement` and
      // then join it with the lattice value for this operation's metadata. Note
      // that during the analysis phase, it is fine to freely create a new
      // lattice element for a value. This is why we don't use the
      // `lookupLatticeElement` method here.
      result |= getLatticeElement(value).join(latticeValue);
    }
    return result;
  }
};
```

With that, we have all of the necessary components to compute our analysis.
After the analysis has been computed, we can grab any computed information for
values by using `lookupLatticeElement`. We use this function over
`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g.
if the value is in an unreachable block, and we don't want to create a new
uninitialized lattice element in this case. See below for a quick example:

```c++
void MyPass::runOnOperation() {
  MetadataAnalysis analysis(&getContext());
  analysis.run(getOperation());
  ...
}

void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) {
  LatticeElement<MetadataLatticeValue> *latticeElement =
      analysis.lookupLatticeElement(value);

  // If we don't have an element, the `value` wasn't visited during our analysis
  // meaning that it could be dead. We need to treat this conservatively.
  if (!latticeElement)
    return;

  // Our lattice element has a value, use it:
  MetadataLatticeValue &latticeValue = latticeElement->getValue();
  ...
}
```
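
To experiment with the `join` rules described in the Lattices section without
building against MLIR, the following is a minimal standalone sketch. It is not
the MLIR API: `Metadata`, `joinValues`, `Element`, and `checkJoinAxioms` are
hypothetical names, `std::map<std::string, std::string>` stands in for
`DenseMap<StringAttr, Attribute>`, `std::optional` models the `uninitialized`
state, and a plain `bool` stands in for `ChangeResult`.

```c++
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-in for MetadataLatticeValue. An empty map plays the role
// of the `top`/`overdefined`/`unknown` state: nothing is known.
using Metadata = std::map<std::string, std::string>;

// Keep only the entries on which both sides agree, mirroring the policy of
// the `join` method in the tutorial.
Metadata joinValues(const Metadata &lhs, const Metadata &rhs) {
  Metadata result;
  for (const auto &entry : lhs) {
    auto it = rhs.find(entry.first);
    if (it != rhs.end() && it->second == entry.second)
      result.insert(entry);
  }
  return result;
}

// Hypothetical stand-in for LatticeElement. An empty optional models the
// `uninitialized` state.
struct Element {
  std::optional<Metadata> value;

  // Join `rhs` into this element. Returns true if the element changed.
  bool join(const Metadata &rhs) {
    // `uninitialized` means "take the other thing".
    if (!value) {
      value = rhs;
      return true;
    }
    Metadata joined = joinValues(*value, rhs);
    if (joined == *value)
      return false;
    value = std::move(joined);
    return true;
  }
};

// Check the axioms listed in the `join` documentation on sample values.
void checkJoinAxioms() {
  Metadata x = {{"likes_pizza", "true"}, {"is_tall", "false"}};
  Metadata y = {{"likes_pizza", "true"}, {"is_tall", "true"}};
  Metadata z = {{"likes_pizza", "false"}};
  assert(joinValues(x, x) == x);                             // idempotence
  assert(joinValues(x, y) == joinValues(y, x));              // commutativity
  assert(joinValues(x, joinValues(y, z)) ==
         joinValues(joinValues(x, y), z));                   // associativity
  assert(joinValues(x, joinValues(x, y)) == joinValues(x, y)); // monotonicity
}
```

In this sketch, joining into a fresh `Element` copies the incoming value, later
joins can only drop entries, and joining with an empty map leaves the element
empty, which matches the behavior the tutorial describes for `uninitialized`
and `overdefined`.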