1 //===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implementation of tests passes exercising dense forward data flow analysis. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "TestDenseDataFlowAnalysis.h" 14 #include "TestDialect.h" 15 #include "TestOps.h" 16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 17 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 18 #include "mlir/Analysis/DataFlow/DenseAnalysis.h" 19 #include "mlir/Interfaces/SideEffectInterfaces.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Support/LLVM.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include <optional> 24 25 using namespace mlir; 26 using namespace mlir::dataflow; 27 using namespace mlir::dataflow::test; 28 29 namespace { 30 31 /// This lattice represents, for a given memory resource, the potential last 32 /// operations that modified the resource. 33 class LastModification : public AbstractDenseLattice, public AccessLatticeBase { 34 public: 35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification) 36 37 using AbstractDenseLattice::AbstractDenseLattice; 38 39 /// Join the last modifications. 40 ChangeResult join(const AbstractDenseLattice &lattice) override { 41 return AccessLatticeBase::merge(static_cast<AccessLatticeBase>( 42 static_cast<const LastModification &>(lattice))); 43 } 44 45 void print(raw_ostream &os) const override { 46 return AccessLatticeBase::print(os); 47 } 48 }; 49 50 class LastModifiedAnalysis 51 : public DenseForwardDataFlowAnalysis<LastModification> { 52 public: 53 explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites) 54 : DenseForwardDataFlowAnalysis(solver), 55 assumeFuncWrites(assumeFuncWrites) {} 56 57 /// Visit an operation. If the operation has no memory effects, then the state 58 /// is propagated with no change. If the operation allocates a resource, then 59 /// its reaching definitions is set to empty. If the operation writes to a 60 /// resource, then its reaching definition is set to the written value. 61 LogicalResult visitOperation(Operation *op, const LastModification &before, 62 LastModification *after) override; 63 64 void visitCallControlFlowTransfer(CallOpInterface call, 65 CallControlFlowAction action, 66 const LastModification &before, 67 LastModification *after) override; 68 69 void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, 70 std::optional<unsigned> regionFrom, 71 std::optional<unsigned> regionTo, 72 const LastModification &before, 73 LastModification *after) override; 74 75 /// At an entry point, the last modifications of all memory resources are 76 /// unknown. 77 void setToEntryState(LastModification *lattice) override { 78 propagateIfChanged(lattice, lattice->reset()); 79 } 80 81 private: 82 const bool assumeFuncWrites; 83 }; 84 } // end anonymous namespace 85 86 LogicalResult LastModifiedAnalysis::visitOperation( 87 Operation *op, const LastModification &before, LastModification *after) { 88 auto memory = dyn_cast<MemoryEffectOpInterface>(op); 89 // If we can't reason about the memory effects, then conservatively assume we 90 // can't deduce anything about the last modifications. 91 if (!memory) { 92 setToEntryState(after); 93 return success(); 94 } 95 96 SmallVector<MemoryEffects::EffectInstance> effects; 97 memory.getEffects(effects); 98 99 // First, check if all underlying values are already known. Otherwise, avoid 100 // propagating and stay in the "undefined" state to avoid incorrectly 101 // propagating values that may be overwritten later on as that could be 102 // problematic for convergence based on monotonicity of lattice updates. 103 SmallVector<Value> underlyingValues; 104 underlyingValues.reserve(effects.size()); 105 for (const auto &effect : effects) { 106 Value value = effect.getValue(); 107 108 // If we see an effect on anything other than a value, assume we can't 109 // deduce anything about the last modifications. 110 if (!value) { 111 setToEntryState(after); 112 return success(); 113 } 114 115 // If we cannot find the underlying value, we shouldn't just propagate the 116 // effects through, return the pessimistic state. 117 std::optional<Value> underlyingValue = 118 UnderlyingValueAnalysis::getMostUnderlyingValue( 119 value, [&](Value value) { 120 return getOrCreateFor<UnderlyingValueLattice>( 121 getProgramPointAfter(op), value); 122 }); 123 124 // If the underlying value is not yet known, don't propagate yet. 125 if (!underlyingValue) 126 return success(); 127 128 underlyingValues.push_back(*underlyingValue); 129 } 130 131 // Update the state when all underlying values are known. 132 ChangeResult result = after->join(before); 133 for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) { 134 // If the underlying value is known to be unknown, set to fixpoint state. 135 if (!value) { 136 setToEntryState(after); 137 return success(); 138 } 139 140 // Nothing to do for reads. 141 if (isa<MemoryEffects::Read>(effect.getEffect())) 142 continue; 143 144 result |= after->set(value, op); 145 } 146 propagateIfChanged(after, result); 147 return success(); 148 } 149 150 void LastModifiedAnalysis::visitCallControlFlowTransfer( 151 CallOpInterface call, CallControlFlowAction action, 152 const LastModification &before, LastModification *after) { 153 if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) { 154 SmallVector<Value> underlyingValues; 155 underlyingValues.reserve(call->getNumOperands()); 156 for (Value operand : call.getArgOperands()) { 157 std::optional<Value> underlyingValue = 158 UnderlyingValueAnalysis::getMostUnderlyingValue( 159 operand, [&](Value value) { 160 return getOrCreateFor<UnderlyingValueLattice>( 161 getProgramPointAfter(call.getOperation()), value); 162 }); 163 if (!underlyingValue) 164 return; 165 underlyingValues.push_back(*underlyingValue); 166 } 167 168 ChangeResult result = after->join(before); 169 for (Value operand : underlyingValues) 170 result |= after->set(operand, call); 171 return propagateIfChanged(after, result); 172 } 173 auto testCallAndStore = 174 dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); 175 if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && 176 testCallAndStore.getStoreBeforeCall()) || 177 (action == CallControlFlowAction::ExitCallee && 178 !testCallAndStore.getStoreBeforeCall()))) { 179 (void)visitOperation(call, before, after); 180 return; 181 } 182 AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer( 183 call, action, before, after); 184 } 185 186 void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer( 187 RegionBranchOpInterface branch, std::optional<unsigned> regionFrom, 188 std::optional<unsigned> regionTo, const LastModification &before, 189 LastModification *after) { 190 auto defaultHandling = [&]() { 191 AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer( 192 branch, regionFrom, regionTo, before, after); 193 }; 194 TypeSwitch<Operation *>(branch.getOperation()) 195 .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>( 196 [=](auto storeWithRegion) { 197 if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) || 198 (!regionFrom && storeWithRegion.getStoreBeforeRegion())) 199 (void)visitOperation(branch, before, after); 200 defaultHandling(); 201 }) 202 .Default([=](auto) { defaultHandling(); }); 203 } 204 205 namespace { 206 struct TestLastModifiedPass 207 : public PassWrapper<TestLastModifiedPass, OperationPass<>> { 208 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass) 209 210 TestLastModifiedPass() = default; 211 TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) { 212 interprocedural = other.interprocedural; 213 assumeFuncWrites = other.assumeFuncWrites; 214 } 215 216 StringRef getArgument() const override { return "test-last-modified"; } 217 218 Option<bool> interprocedural{ 219 *this, "interprocedural", llvm::cl::init(true), 220 llvm::cl::desc("perform interprocedural analysis")}; 221 Option<bool> assumeFuncWrites{ 222 *this, "assume-func-writes", llvm::cl::init(false), 223 llvm::cl::desc( 224 "assume external functions have write effect on all arguments")}; 225 226 void runOnOperation() override { 227 Operation *op = getOperation(); 228 229 DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); 230 solver.load<DeadCodeAnalysis>(); 231 solver.load<SparseConstantPropagation>(); 232 solver.load<LastModifiedAnalysis>(assumeFuncWrites); 233 solver.load<UnderlyingValueAnalysis>(); 234 if (failed(solver.initializeAndRun(op))) 235 return signalPassFailure(); 236 237 raw_ostream &os = llvm::errs(); 238 239 // Note that if the underlying value could not be computed or is unknown, we 240 // conservatively treat the result also unknown. 241 op->walk([&](Operation *op) { 242 auto tag = op->getAttrOfType<StringAttr>("tag"); 243 if (!tag) 244 return; 245 os << "test_tag: " << tag.getValue() << ":\n"; 246 const LastModification *lastMods = 247 solver.lookupState<LastModification>(solver.getProgramPointAfter(op)); 248 assert(lastMods && "expected a dense lattice"); 249 for (auto [index, operand] : llvm::enumerate(op->getOperands())) { 250 os << " operand #" << index << "\n"; 251 std::optional<Value> underlyingValue = 252 UnderlyingValueAnalysis::getMostUnderlyingValue( 253 operand, [&](Value value) { 254 return solver.lookupState<UnderlyingValueLattice>(value); 255 }); 256 if (!underlyingValue) { 257 os << " - <unknown>\n"; 258 continue; 259 } 260 Value value = *underlyingValue; 261 assert(value && "expected an underlying value"); 262 if (const AdjacentAccess *lastMod = 263 lastMods->getAdjacentAccess(value)) { 264 if (!lastMod->isKnown()) { 265 os << " - <unknown>\n"; 266 } else { 267 for (Operation *lastModifier : lastMod->get()) { 268 if (auto tagName = 269 lastModifier->getAttrOfType<StringAttr>("tag_name")) { 270 os << " - " << tagName.getValue() << "\n"; 271 } else { 272 os << " - " << lastModifier->getName() << "\n"; 273 } 274 } 275 } 276 } else { 277 os << " - <unknown>\n"; 278 } 279 } 280 }); 281 } 282 }; 283 } // end anonymous namespace 284 285 namespace mlir { 286 namespace test { 287 void registerTestLastModifiedPass() { 288 PassRegistration<TestLastModifiedPass>(); 289 } 290 } // end namespace test 291 } // end namespace mlir 292