xref: /llvm-project/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h (revision b6603e1bf11dee4761e49af6581c8b8f074b705d)
1 //===- TestDenseDataFlowAnalysis.h - Dataflow test utilities ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10 #include "mlir/Analysis/DataFlowFramework.h"
11 #include "mlir/IR/Value.h"
12 #include "llvm/ADT/DenseMap.h"
13 #include "llvm/Support/raw_ostream.h"
14 #include <optional>
15 
16 namespace mlir {
17 namespace dataflow {
18 namespace test {
19 
20 /// This lattice represents a single underlying value for an SSA value.
21 class UnderlyingValue {
22 public:
23   /// Create an underlying value state with a known underlying value.
24   explicit UnderlyingValue(std::optional<Value> underlyingValue = std::nullopt)
25       : underlyingValue(underlyingValue) {}
26 
27   /// Whether the state is uninitialized.
28   bool isUninitialized() const { return !underlyingValue.has_value(); }
29 
30   /// Returns the underlying value.
31   Value getUnderlyingValue() const {
32     assert(!isUninitialized());
33     return *underlyingValue;
34   }
35 
36   /// Join two underlying values. If there are conflicting underlying values,
37   /// go to the pessimistic value.
38   static UnderlyingValue join(const UnderlyingValue &lhs,
39                               const UnderlyingValue &rhs) {
40     if (lhs.isUninitialized())
41       return rhs;
42     if (rhs.isUninitialized())
43       return lhs;
44     return lhs.underlyingValue == rhs.underlyingValue
45                ? lhs
46                : UnderlyingValue(Value{});
47   }
48 
49   /// Compare underlying values.
50   bool operator==(const UnderlyingValue &rhs) const {
51     return underlyingValue == rhs.underlyingValue;
52   }
53 
54   void print(raw_ostream &os) const { os << underlyingValue; }
55 
56 private:
57   std::optional<Value> underlyingValue;
58 };
59 
60 class AdjacentAccess {
61 public:
62   using DeterministicSetVector =
63       SetVector<Operation *, SmallVector<Operation *, 2>,
64                 SmallPtrSet<Operation *, 2>>;
65 
66   ArrayRef<Operation *> get() const { return accesses.getArrayRef(); }
67   bool isKnown() const { return !unknown; }
68 
69   ChangeResult merge(const AdjacentAccess &other) {
70     if (unknown)
71       return ChangeResult::NoChange;
72     if (other.unknown) {
73       unknown = true;
74       accesses.clear();
75       return ChangeResult::Change;
76     }
77 
78     size_t sizeBefore = accesses.size();
79     accesses.insert(other.accesses.begin(), other.accesses.end());
80     return accesses.size() == sizeBefore ? ChangeResult::NoChange
81                                          : ChangeResult::Change;
82   }
83 
84   ChangeResult set(Operation *op) {
85     if (!unknown && accesses.size() == 1 && *accesses.begin() == op)
86       return ChangeResult::NoChange;
87 
88     unknown = false;
89     accesses.clear();
90     accesses.insert(op);
91     return ChangeResult::Change;
92   }
93 
94   ChangeResult setUnknown() {
95     if (unknown)
96       return ChangeResult::NoChange;
97 
98     accesses.clear();
99     unknown = true;
100     return ChangeResult::Change;
101   }
102 
103   bool operator==(const AdjacentAccess &other) const {
104     return unknown == other.unknown && accesses == other.accesses;
105   }
106 
107   bool operator!=(const AdjacentAccess &other) const {
108     return !operator==(other);
109   }
110 
111 private:
112   bool unknown = false;
113   DeterministicSetVector accesses;
114 };
115 
116 /// This lattice represents, for a given memory resource, the potential last
117 /// operations that modified the resource.
118 class AccessLatticeBase {
119 public:
120   /// Clear all modifications.
121   ChangeResult reset() {
122     if (adjAccesses.empty())
123       return ChangeResult::NoChange;
124     adjAccesses.clear();
125     return ChangeResult::Change;
126   }
127 
128   /// Join the last modifications.
129   ChangeResult merge(const AccessLatticeBase &rhs) {
130     ChangeResult result = ChangeResult::NoChange;
131     for (const auto &mod : rhs.adjAccesses) {
132       AdjacentAccess &lhsMod = adjAccesses[mod.first];
133       result |= lhsMod.merge(mod.second);
134     }
135     return result;
136   }
137 
138   /// Set the last modification of a value.
139   ChangeResult set(Value value, Operation *op) {
140     AdjacentAccess &lastMod = adjAccesses[value];
141     return lastMod.set(op);
142   }
143 
144   ChangeResult setKnownToUnknown() {
145     ChangeResult result = ChangeResult::NoChange;
146     for (auto &[value, adjacent] : adjAccesses)
147       result |= adjacent.setUnknown();
148     return result;
149   }
150 
151   /// Get the adjacent accesses to a value. Returns std::nullopt if they
152   /// are not known.
153   const AdjacentAccess *getAdjacentAccess(Value value) const {
154     auto it = adjAccesses.find(value);
155     if (it == adjAccesses.end())
156       return nullptr;
157     return &it->getSecond();
158   }
159 
160   void print(raw_ostream &os) const {
161     for (const auto &lastMod : adjAccesses) {
162       os << lastMod.first << ":\n";
163       if (!lastMod.second.isKnown()) {
164         os << "  <unknown>\n";
165         return;
166       }
167       for (Operation *op : lastMod.second.get())
168         os << "  " << *op << "\n";
169     }
170   }
171 
172 private:
173   /// The potential adjacent accesses to a memory resource. Use a set vector to
174   /// keep the results deterministic.
175   DenseMap<Value, AdjacentAccess> adjAccesses;
176 };
177 
178 /// Define the lattice class explicitly to provide a type ID.
179 struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
180   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
181   using Lattice::Lattice;
182 };
183 
184 /// An analysis that uses forwarding of values along control-flow and callgraph
185 /// edges to determine single underlying values for block arguments. This
186 /// analysis exists so that the test analysis and pass can test the behaviour of
187 /// the dense data-flow analysis on the callgraph.
188 class UnderlyingValueAnalysis
189     : public SparseForwardDataFlowAnalysis<UnderlyingValueLattice> {
190 public:
191   using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
192 
193   /// The underlying value of the results of an operation are not known.
194   LogicalResult
195   visitOperation(Operation *op,
196                  ArrayRef<const UnderlyingValueLattice *> operands,
197                  ArrayRef<UnderlyingValueLattice *> results) override {
198     // Hook to test error propagation from visitOperation.
199     if (op->hasAttr("always_fail"))
200       return op->emitError("this op is always fails");
201 
202     setAllToEntryStates(results);
203     return success();
204   }
205 
206   /// At an entry point, the underlying value of a value is itself.
207   void setToEntryState(UnderlyingValueLattice *lattice) override {
208     propagateIfChanged(lattice,
209                        lattice->join(UnderlyingValue{lattice->getAnchor()}));
210   }
211 
212   /// Look for the most underlying value of a value.
213   static std::optional<Value>
214   getMostUnderlyingValue(Value value,
215                          function_ref<const UnderlyingValueLattice *(Value)>
216                              getUnderlyingValueFn) {
217     const UnderlyingValueLattice *underlying;
218     do {
219       underlying = getUnderlyingValueFn(value);
220       if (!underlying || underlying->getValue().isUninitialized())
221         return std::nullopt;
222       Value underlyingValue = underlying->getValue().getUnderlyingValue();
223       if (underlyingValue == value)
224         break;
225       value = underlyingValue;
226     } while (true);
227     return value;
228   }
229 };
230 
231 } // namespace test
232 } // namespace dataflow
233 } // namespace mlir
234