# Writing DataFlow Analyses in MLIR

Writing dataflow analyses in MLIR, or in any compiler, can often seem quite
daunting and/or complex. A dataflow analysis generally involves propagating
information about the IR across various different types of control flow
constructs, of which MLIR has many (Block-based branches, Region-based branches,
CallGraph, etc.), and it isn't always clear how best to go about performing the
propagation. To help with writing these types of analyses in MLIR, this document
details several utilities that simplify the process and make it a bit more
approachable.

## Forward Dataflow Analysis

One type of dataflow analysis is a forward propagation analysis. This type of
analysis, as the name may suggest, propagates information forward (e.g. from
definitions to uses). To provide a bit of concrete context, let's go over
writing a simple forward dataflow analysis in MLIR. Let's say for this analysis
that we want to propagate information about a special "metadata" dictionary
attribute. The contents of this attribute are simply a set of metadata that
describe a specific value, e.g. `metadata = { likes_pizza = true }`. We will
collect the `metadata` for operations in the IR and propagate them about.
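
To make the direction of flow concrete, here is a minimal, MLIR-free sketch of
"forward" propagation over a toy IR. Everything in it is hypothetical and
exists only for illustration: `Metadata`, `ToyOp`, and `propagateForward` are
not MLIR APIs, and the policy used here (a result inherits any operand fact
that its own operation does not define) is an assumption chosen purely to show
facts moving from definitions to uses. It is not the policy of the analysis
developed in this document.

```c++
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for the `metadata` dictionary: key -> boolean fact.
using Metadata = std::map<std::string, bool>;

// Hypothetical toy operation: it defines one value and uses the values
// defined by earlier operations, referenced by their index.
struct ToyOp {
  std::vector<int> operands; // indices of the ops whose results are used
  Metadata metadata;         // facts attached directly to this op (may be empty)
};

// Visit the ops in definition order. Each result starts from the facts on its
// own op and then inherits any operand fact that it does not already define.
std::vector<Metadata> propagateForward(const std::vector<ToyOp> &ops) {
  std::vector<Metadata> facts(ops.size());
  for (std::size_t i = 0; i < ops.size(); ++i) {
    facts[i] = ops[i].metadata;
    for (int operand : ops[i].operands)
      for (const auto &entry : facts[operand])
        facts[i].insert(entry); // `insert` leaves an existing key unchanged
  }
  return facts;
}
```

Because operations are visited in definition order, the facts of a definition
are always computed before any of its uses are visited, which is the essence
of a forward analysis.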

### Lattices

Before going into how one might set up the analysis itself, it is important to
first introduce the concept of a `Lattice` and how we will use it for the
analysis. A lattice represents all of the possible values or results of the
analysis for a given value. A lattice element holds the set of information
computed by the analysis for a given value, and is what gets propagated across
the IR. For our analysis, this would correspond to the `metadata` dictionary
attribute.

Regardless of the value held within, every type of lattice contains two special
element states:

* `uninitialized`

  - The element has not been initialized.

* `top`/`overdefined`/`unknown`

  - The element encompasses every possible value.
  - This is a very conservative state, and essentially means "I can't make
    any assumptions about the value, it could be anything"

These two states are important when merging, or `join`ing as we will refer to it
further in this document, information as part of the analysis. Lattice elements
are `join`ed whenever there are two different source points, such as an argument
to a block with multiple predecessors. One important note about the `join`
operation is that it is required to be monotonic (see the `join` method in the
example below for more information).
This ensures that `join`ing elements is
consistent. The two special states mentioned above have unique properties during
a `join`:

* `uninitialized`

  - If one of the elements is `uninitialized`, the other element is used.
  - `uninitialized` in the context of a `join` essentially means "take the
    other thing".

* `top`/`overdefined`/`unknown`

  - If one of the elements being joined is `overdefined`, the result is
    `overdefined`.

For our analysis in MLIR, we will need to define a class representing the value
held by an element of the lattice used by our dataflow analysis:

```c++
/// The value of our lattice represents the inner structure of a DictionaryAttr,
/// for the `metadata`.
struct MetadataLatticeValue {
  MetadataLatticeValue() = default;
  /// Compute a lattice value from the provided dictionary.
  MetadataLatticeValue(DictionaryAttr attr)
      : metadata(attr.begin(), attr.end()) {}

  /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown`
  /// state, for our value type. The resultant state should not assume any
  /// information about the state of the IR.
  static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) {
    // The `top`/`overdefined`/`unknown` state is when we know nothing about any
    // metadata, i.e. an empty dictionary.
    return MetadataLatticeValue();
  }
  /// Return a pessimistic value state for our value type using only information
  /// about the state of the provided IR. This is similar to the above method,
  /// but may produce a slightly more refined result. This is okay, as the
  /// information is already encoded as fact in the IR.
  static MetadataLatticeValue getPessimisticValueState(Value value) {
    // Check to see if the parent operation has metadata.
    if (Operation *parentOp = value.getDefiningOp()) {
      if (auto metadata = parentOp->getAttrOfType<DictionaryAttr>("metadata"))
        return MetadataLatticeValue(metadata);

      // If no metadata is present, fall back to the
      // `top`/`overdefined`/`unknown` state.
    }
    return MetadataLatticeValue();
  }

  /// This method conservatively joins the information held by `lhs` and `rhs`
  /// into a new value. This method is required to be monotonic.
  /// `monotonicity` is implied by the satisfaction of the following axioms:
  ///   * idempotence:   join(x,x) == x
  ///   * commutativity: join(x,y) == join(y,x)
  ///   * associativity: join(x,join(y,z)) == join(join(x,y),z)
  ///
  /// When the above axioms are satisfied, we achieve `monotonicity`:
  ///   * monotonicity: join(x, join(x,y)) == join(x,y)
  static MetadataLatticeValue join(const MetadataLatticeValue &lhs,
                                   const MetadataLatticeValue &rhs) {
    // To join `lhs` and `rhs` we will define a simple policy, which is that we
    // only keep information that is the same. This means that we only keep
    // facts that are true in both.
    MetadataLatticeValue result;
    for (const auto &lhsIt : lhs.metadata) {
      // As noted above, we only merge if the values are the same. `find`
      // returns an iterator, so the mapped value is accessed with `->`.
      auto it = rhs.metadata.find(lhsIt.first);
      if (it == rhs.metadata.end() || it->second != lhsIt.second)
        continue;
      // Insert into the map held by `result`, not into `result` itself.
      result.metadata.insert(lhsIt);
    }
    return result;
  }

  /// A simple comparator that checks to see if this value is equal to the one
  /// provided.
  bool operator==(const MetadataLatticeValue &rhs) const {
    if (metadata.size() != rhs.metadata.size())
      return false;
    // Check that `rhs` contains the same metadata.
    for (const auto &it : metadata) {
      auto rhsIt = rhs.metadata.find(it.first);
      if (rhsIt == rhs.metadata.end() || it.second != rhsIt->second)
        return false;
    }
    return true;
  }

  /// Our value represents the combined metadata, which is originally a
  /// DictionaryAttr, so we use a map.
  DenseMap<StringAttr, Attribute> metadata;
};
```

One interesting thing to note above is that we don't have an explicit method for
the `uninitialized` state. This state is handled by the `LatticeElement` class,
which manages a lattice value for a given IR entity. A quick overview of this
class, and the API that will be interesting to us while writing our analysis, is
shown below:

```c++
/// This class represents a lattice element holding a specific value of type
/// `ValueT`.
template <typename ValueT>
class LatticeElement ... {
public:
  /// Return the value held by this element. This requires that a value is
  /// known, i.e. not `uninitialized`.
  ValueT &getValue();
  const ValueT &getValue() const;

  /// Join the information contained in the 'rhs' element into this
  /// element. Returns whether the state of the current element changed.
  ChangeResult join(const LatticeElement<ValueT> &rhs);

  /// Join the information contained in the 'rhs' value into this
  /// lattice. Returns whether the state of the current lattice changed.
  ChangeResult join(const ValueT &rhs);

  /// Mark the lattice element as having reached a pessimistic fixpoint. This
  /// means that the lattice may potentially have conflicting value states, and
  /// only the conservatively known value state should be relied on.
  ChangeResult markPessimisticFixPoint();
};
```

With our lattice defined, we can now define the driver that will compute and
propagate our lattice across the IR.

### ForwardDataflowAnalysis Driver

The `ForwardDataFlowAnalysis` class represents the driver of the dataflow
analysis, and performs all of the related analysis computation. When defining
our analysis, we will inherit from this class and implement some of its hooks.
Before that, let's look at a quick overview of this class and some of the
important API for our analysis:

```c++
/// This class represents the main driver of the forward dataflow analysis. It
/// takes as a template parameter the value type of lattice being computed.
template <typename ValueT>
class ForwardDataFlowAnalysis : ... {
public:
  ForwardDataFlowAnalysis(MLIRContext *context);

  /// Compute the analysis on operations rooted under the given top-level
  /// operation. Note that the top-level operation is not visited.
  void run(Operation *topLevelOp);

  /// Return the lattice element attached to the given value. If a lattice has
  /// not been added for the given value, a new 'uninitialized' value is
  /// inserted and returned.
  LatticeElement<ValueT> &getLatticeElement(Value value);

  /// Return the lattice element attached to the given value, or nullptr if no
  /// lattice element for the value has yet been created.
  LatticeElement<ValueT> *lookupLatticeElement(Value value);

  /// Mark all of the lattice elements for the given range of Values as having
  /// reached a pessimistic fixpoint.
  ChangeResult markAllPessimisticFixPoint(ValueRange values);

protected:
  /// Visit the given operation, and join any necessary analysis state
  /// into the lattice elements for the results and block arguments owned by
  /// this operation using the provided set of operand lattice elements
  /// (all pointer values are guaranteed to be non-null). Returns whether any
  /// result or block argument value lattice elements changed during the visit.
  /// The lattice element for a result or block argument value can be obtained,
  /// and join'ed into, by using `getLatticeElement`.
  virtual ChangeResult visitOperation(
      Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) = 0;
};
```

NOTE: Some API has been omitted for our example. The `ForwardDataFlowAnalysis`
contains various other hooks that allow for injecting custom behavior when
applicable.

The main API that we are responsible for defining is the `visitOperation`
method. This method is responsible for computing new lattice elements for the
results and block arguments owned by the given operation. This is where we will
inject the lattice element computation logic, also known as the transfer
function for the operation, that is specific to our analysis.
A simple
implementation for our example is shown below:

```c++
class MetadataAnalysis : public ForwardDataFlowAnalysis<MetadataLatticeValue> {
public:
  using ForwardDataFlowAnalysis<MetadataLatticeValue>::ForwardDataFlowAnalysis;

  // `ValueT` is a template parameter of the base class, so the override names
  // the concrete value type, `MetadataLatticeValue`, instead.
  ChangeResult visitOperation(
      Operation *op,
      ArrayRef<LatticeElement<MetadataLatticeValue> *> operands) override {
    DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");

    // If we have no metadata for this operation, we will conservatively mark
    // all of the results as having reached a pessimistic fixpoint.
    if (!metadata)
      return markAllPessimisticFixPoint(op->getResults());

    // Otherwise, we will compute a lattice value for the metadata and join it
    // into the current lattice element for all of our results.
    MetadataLatticeValue latticeValue(metadata);
    ChangeResult result = ChangeResult::NoChange;
    for (Value value : op->getResults()) {
      // We grab the lattice element for `value` via `getLatticeElement` and
      // then join it with the lattice value for this operation's metadata. Note
      // that during the analysis phase, it is fine to freely create a new
      // lattice element for a value. This is why we don't use the
      // `lookupLatticeElement` method here.
      result |= getLatticeElement(value).join(latticeValue);
    }
    return result;
  }
};
```

With that, we have all of the necessary components to compute our analysis.
After the analysis has been computed, we can grab any computed information for
values by using `lookupLatticeElement`. We use this function over
`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g.
if the value is in an unreachable block, and we don't want to create a new
uninitialized lattice element in this case. See below for a quick example:

```c++
void MyPass::runOnOperation() {
  MetadataAnalysis analysis(&getContext());
  analysis.run(getOperation());
  ...
}

void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) {
  LatticeElement<MetadataLatticeValue> *latticeElement =
      analysis.lookupLatticeElement(value);

  // If we don't have an element, the `value` wasn't visited during our analysis
  // meaning that it could be dead. We need to treat this conservatively.
  if (!latticeElement)
    return;

  // Our lattice element has a value, use it:
  MetadataLatticeValue &metadata = latticeElement->getValue();
  ...
}
```
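
To experiment with the `join` rules from the Lattices section without building
MLIR, the following standalone sketch mirrors them using only the C++ standard
library. The names `ToyValue`, `joinValues`, and `ToyElement` are hypothetical
stand-ins for `MetadataLatticeValue`, its `join` method, and `LatticeElement`.
They are assumptions made for illustration and are not part of MLIR. An empty
`std::optional` models the `uninitialized` state, and an empty map models the
`top`/`overdefined`/`unknown` state.

```c++
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-in for `MetadataLatticeValue`: key -> boolean fact.
using ToyValue = std::map<std::string, bool>;

// Keep only the facts that are present, with the same value, on both sides.
ToyValue joinValues(const ToyValue &lhs, const ToyValue &rhs) {
  ToyValue result;
  for (const auto &entry : lhs) {
    auto it = rhs.find(entry.first);
    if (it != rhs.end() && it->second == entry.second)
      result.insert(entry);
  }
  return result;
}

// Hypothetical stand-in for `LatticeElement`.
struct ToyElement {
  std::optional<ToyValue> value; // empty optional == `uninitialized`

  // Join `rhs` into this element; returns true if this element changed.
  bool join(const ToyValue &rhs) {
    if (!value) { // `uninitialized`: "take the other thing"
      value = rhs;
      return true;
    }
    ToyValue joined = joinValues(*value, rhs);
    if (joined == *value)
      return false;
    value = std::move(joined);
    return true;
  }
};
```

With this policy, joining anything with the empty map yields the empty map,
which matches the rule that joining with `overdefined` gives `overdefined`.
The boolean returned by `ToyElement::join` plays the role of `ChangeResult`:
a driver would keep revisiting users of a value only while joins keep
reporting a change.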