1 //===- TestDenseDataFlowAnalysis.h - Dataflow test utilities ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h" 10 #include "mlir/Analysis/DataFlowFramework.h" 11 #include "mlir/IR/Value.h" 12 #include "llvm/ADT/DenseMap.h" 13 #include "llvm/Support/raw_ostream.h" 14 #include <optional> 15 16 namespace mlir { 17 namespace dataflow { 18 namespace test { 19 20 /// This lattice represents a single underlying value for an SSA value. 21 class UnderlyingValue { 22 public: 23 /// Create an underlying value state with a known underlying value. 24 explicit UnderlyingValue(std::optional<Value> underlyingValue = std::nullopt) 25 : underlyingValue(underlyingValue) {} 26 27 /// Whether the state is uninitialized. 28 bool isUninitialized() const { return !underlyingValue.has_value(); } 29 30 /// Returns the underlying value. 31 Value getUnderlyingValue() const { 32 assert(!isUninitialized()); 33 return *underlyingValue; 34 } 35 36 /// Join two underlying values. If there are conflicting underlying values, 37 /// go to the pessimistic value. 38 static UnderlyingValue join(const UnderlyingValue &lhs, 39 const UnderlyingValue &rhs) { 40 if (lhs.isUninitialized()) 41 return rhs; 42 if (rhs.isUninitialized()) 43 return lhs; 44 return lhs.underlyingValue == rhs.underlyingValue 45 ? lhs 46 : UnderlyingValue(Value{}); 47 } 48 49 /// Compare underlying values. 50 bool operator==(const UnderlyingValue &rhs) const { 51 return underlyingValue == rhs.underlyingValue; 52 } 53 54 void print(raw_ostream &os) const { os << underlyingValue; } 55 56 private: 57 std::optional<Value> underlyingValue; 58 }; 59 60 class AdjacentAccess { 61 public: 62 using DeterministicSetVector = 63 SetVector<Operation *, SmallVector<Operation *, 2>, 64 SmallPtrSet<Operation *, 2>>; 65 66 ArrayRef<Operation *> get() const { return accesses.getArrayRef(); } 67 bool isKnown() const { return !unknown; } 68 69 ChangeResult merge(const AdjacentAccess &other) { 70 if (unknown) 71 return ChangeResult::NoChange; 72 if (other.unknown) { 73 unknown = true; 74 accesses.clear(); 75 return ChangeResult::Change; 76 } 77 78 size_t sizeBefore = accesses.size(); 79 accesses.insert(other.accesses.begin(), other.accesses.end()); 80 return accesses.size() == sizeBefore ? ChangeResult::NoChange 81 : ChangeResult::Change; 82 } 83 84 ChangeResult set(Operation *op) { 85 if (!unknown && accesses.size() == 1 && *accesses.begin() == op) 86 return ChangeResult::NoChange; 87 88 unknown = false; 89 accesses.clear(); 90 accesses.insert(op); 91 return ChangeResult::Change; 92 } 93 94 ChangeResult setUnknown() { 95 if (unknown) 96 return ChangeResult::NoChange; 97 98 accesses.clear(); 99 unknown = true; 100 return ChangeResult::Change; 101 } 102 103 bool operator==(const AdjacentAccess &other) const { 104 return unknown == other.unknown && accesses == other.accesses; 105 } 106 107 bool operator!=(const AdjacentAccess &other) const { 108 return !operator==(other); 109 } 110 111 private: 112 bool unknown = false; 113 DeterministicSetVector accesses; 114 }; 115 116 /// This lattice represents, for a given memory resource, the potential last 117 /// operations that modified the resource. 118 class AccessLatticeBase { 119 public: 120 /// Clear all modifications. 121 ChangeResult reset() { 122 if (adjAccesses.empty()) 123 return ChangeResult::NoChange; 124 adjAccesses.clear(); 125 return ChangeResult::Change; 126 } 127 128 /// Join the last modifications. 129 ChangeResult merge(const AccessLatticeBase &rhs) { 130 ChangeResult result = ChangeResult::NoChange; 131 for (const auto &mod : rhs.adjAccesses) { 132 AdjacentAccess &lhsMod = adjAccesses[mod.first]; 133 result |= lhsMod.merge(mod.second); 134 } 135 return result; 136 } 137 138 /// Set the last modification of a value. 139 ChangeResult set(Value value, Operation *op) { 140 AdjacentAccess &lastMod = adjAccesses[value]; 141 return lastMod.set(op); 142 } 143 144 ChangeResult setKnownToUnknown() { 145 ChangeResult result = ChangeResult::NoChange; 146 for (auto &[value, adjacent] : adjAccesses) 147 result |= adjacent.setUnknown(); 148 return result; 149 } 150 151 /// Get the adjacent accesses to a value. Returns std::nullopt if they 152 /// are not known. 153 const AdjacentAccess *getAdjacentAccess(Value value) const { 154 auto it = adjAccesses.find(value); 155 if (it == adjAccesses.end()) 156 return nullptr; 157 return &it->getSecond(); 158 } 159 160 void print(raw_ostream &os) const { 161 for (const auto &lastMod : adjAccesses) { 162 os << lastMod.first << ":\n"; 163 if (!lastMod.second.isKnown()) { 164 os << " <unknown>\n"; 165 return; 166 } 167 for (Operation *op : lastMod.second.get()) 168 os << " " << *op << "\n"; 169 } 170 } 171 172 private: 173 /// The potential adjacent accesses to a memory resource. Use a set vector to 174 /// keep the results deterministic. 175 DenseMap<Value, AdjacentAccess> adjAccesses; 176 }; 177 178 /// Define the lattice class explicitly to provide a type ID. 179 struct UnderlyingValueLattice : public Lattice<UnderlyingValue> { 180 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice) 181 using Lattice::Lattice; 182 }; 183 184 /// An analysis that uses forwarding of values along control-flow and callgraph 185 /// edges to determine single underlying values for block arguments. This 186 /// analysis exists so that the test analysis and pass can test the behaviour of 187 /// the dense data-flow analysis on the callgraph. 188 class UnderlyingValueAnalysis 189 : public SparseForwardDataFlowAnalysis<UnderlyingValueLattice> { 190 public: 191 using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; 192 193 /// The underlying value of the results of an operation are not known. 194 LogicalResult 195 visitOperation(Operation *op, 196 ArrayRef<const UnderlyingValueLattice *> operands, 197 ArrayRef<UnderlyingValueLattice *> results) override { 198 // Hook to test error propagation from visitOperation. 199 if (op->hasAttr("always_fail")) 200 return op->emitError("this op is always fails"); 201 202 setAllToEntryStates(results); 203 return success(); 204 } 205 206 /// At an entry point, the underlying value of a value is itself. 207 void setToEntryState(UnderlyingValueLattice *lattice) override { 208 propagateIfChanged(lattice, 209 lattice->join(UnderlyingValue{lattice->getAnchor()})); 210 } 211 212 /// Look for the most underlying value of a value. 213 static std::optional<Value> 214 getMostUnderlyingValue(Value value, 215 function_ref<const UnderlyingValueLattice *(Value)> 216 getUnderlyingValueFn) { 217 const UnderlyingValueLattice *underlying; 218 do { 219 underlying = getUnderlyingValueFn(value); 220 if (!underlying || underlying->getValue().isUninitialized()) 221 return std::nullopt; 222 Value underlyingValue = underlying->getValue().getUnderlyingValue(); 223 if (underlyingValue == value) 224 break; 225 value = underlyingValue; 226 } while (true); 227 return value; 228 } 229 }; 230 231 } // namespace test 232 } // namespace dataflow 233 } // namespace mlir 234