xref: /llvm-project/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
1 //===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Implementation of test passes exercising dense forward data flow analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "TestDenseDataFlowAnalysis.h"
14 #include "TestDialect.h"
15 #include "TestOps.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
18 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Support/LLVM.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include <optional>
24 
25 using namespace mlir;
26 using namespace mlir::dataflow;
27 using namespace mlir::dataflow::test;
28 
29 namespace {
30 
31 /// This lattice represents, for a given memory resource, the potential last
32 /// operations that modified the resource.
33 class LastModification : public AbstractDenseLattice, public AccessLatticeBase {
34 public:
35   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
36 
37   using AbstractDenseLattice::AbstractDenseLattice;
38 
39   /// Join the last modifications.
40   ChangeResult join(const AbstractDenseLattice &lattice) override {
41     return AccessLatticeBase::merge(static_cast<AccessLatticeBase>(
42         static_cast<const LastModification &>(lattice)));
43   }
44 
45   void print(raw_ostream &os) const override {
46     return AccessLatticeBase::print(os);
47   }
48 };
49 
50 class LastModifiedAnalysis
51     : public DenseForwardDataFlowAnalysis<LastModification> {
52 public:
53   explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites)
54       : DenseForwardDataFlowAnalysis(solver),
55         assumeFuncWrites(assumeFuncWrites) {}
56 
57   /// Visit an operation. If the operation has no memory effects, then the state
58   /// is propagated with no change. If the operation allocates a resource, then
59   /// its reaching definitions is set to empty. If the operation writes to a
60   /// resource, then its reaching definition is set to the written value.
61   LogicalResult visitOperation(Operation *op, const LastModification &before,
62                                LastModification *after) override;
63 
64   void visitCallControlFlowTransfer(CallOpInterface call,
65                                     CallControlFlowAction action,
66                                     const LastModification &before,
67                                     LastModification *after) override;
68 
69   void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
70                                             std::optional<unsigned> regionFrom,
71                                             std::optional<unsigned> regionTo,
72                                             const LastModification &before,
73                                             LastModification *after) override;
74 
75   /// At an entry point, the last modifications of all memory resources are
76   /// unknown.
77   void setToEntryState(LastModification *lattice) override {
78     propagateIfChanged(lattice, lattice->reset());
79   }
80 
81 private:
82   const bool assumeFuncWrites;
83 };
84 } // end anonymous namespace
85 
86 LogicalResult LastModifiedAnalysis::visitOperation(
87     Operation *op, const LastModification &before, LastModification *after) {
88   auto memory = dyn_cast<MemoryEffectOpInterface>(op);
89   // If we can't reason about the memory effects, then conservatively assume we
90   // can't deduce anything about the last modifications.
91   if (!memory) {
92     setToEntryState(after);
93     return success();
94   }
95 
96   SmallVector<MemoryEffects::EffectInstance> effects;
97   memory.getEffects(effects);
98 
99   // First, check if all underlying values are already known. Otherwise, avoid
100   // propagating and stay in the "undefined" state to avoid incorrectly
101   // propagating values that may be overwritten later on as that could be
102   // problematic for convergence based on monotonicity of lattice updates.
103   SmallVector<Value> underlyingValues;
104   underlyingValues.reserve(effects.size());
105   for (const auto &effect : effects) {
106     Value value = effect.getValue();
107 
108     // If we see an effect on anything other than a value, assume we can't
109     // deduce anything about the last modifications.
110     if (!value) {
111       setToEntryState(after);
112       return success();
113     }
114 
115     // If we cannot find the underlying value, we shouldn't just propagate the
116     // effects through, return the pessimistic state.
117     std::optional<Value> underlyingValue =
118         UnderlyingValueAnalysis::getMostUnderlyingValue(
119             value, [&](Value value) {
120               return getOrCreateFor<UnderlyingValueLattice>(
121                   getProgramPointAfter(op), value);
122             });
123 
124     // If the underlying value is not yet known, don't propagate yet.
125     if (!underlyingValue)
126       return success();
127 
128     underlyingValues.push_back(*underlyingValue);
129   }
130 
131   // Update the state when all underlying values are known.
132   ChangeResult result = after->join(before);
133   for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) {
134     // If the underlying value is known to be unknown, set to fixpoint state.
135     if (!value) {
136       setToEntryState(after);
137       return success();
138     }
139 
140     // Nothing to do for reads.
141     if (isa<MemoryEffects::Read>(effect.getEffect()))
142       continue;
143 
144     result |= after->set(value, op);
145   }
146   propagateIfChanged(after, result);
147   return success();
148 }
149 
150 void LastModifiedAnalysis::visitCallControlFlowTransfer(
151     CallOpInterface call, CallControlFlowAction action,
152     const LastModification &before, LastModification *after) {
153   if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) {
154     SmallVector<Value> underlyingValues;
155     underlyingValues.reserve(call->getNumOperands());
156     for (Value operand : call.getArgOperands()) {
157       std::optional<Value> underlyingValue =
158           UnderlyingValueAnalysis::getMostUnderlyingValue(
159               operand, [&](Value value) {
160                 return getOrCreateFor<UnderlyingValueLattice>(
161                     getProgramPointAfter(call.getOperation()), value);
162               });
163       if (!underlyingValue)
164         return;
165       underlyingValues.push_back(*underlyingValue);
166     }
167 
168     ChangeResult result = after->join(before);
169     for (Value operand : underlyingValues)
170       result |= after->set(operand, call);
171     return propagateIfChanged(after, result);
172   }
173   auto testCallAndStore =
174       dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
175   if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
176                             testCallAndStore.getStoreBeforeCall()) ||
177                            (action == CallControlFlowAction::ExitCallee &&
178                             !testCallAndStore.getStoreBeforeCall()))) {
179     (void)visitOperation(call, before, after);
180     return;
181   }
182   AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer(
183       call, action, before, after);
184 }
185 
186 void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
187     RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
188     std::optional<unsigned> regionTo, const LastModification &before,
189     LastModification *after) {
190   auto defaultHandling = [&]() {
191     AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
192         branch, regionFrom, regionTo, before, after);
193   };
194   TypeSwitch<Operation *>(branch.getOperation())
195       .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
196           [=](auto storeWithRegion) {
197             if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) ||
198                 (!regionFrom && storeWithRegion.getStoreBeforeRegion()))
199               (void)visitOperation(branch, before, after);
200             defaultHandling();
201           })
202       .Default([=](auto) { defaultHandling(); });
203 }
204 
205 namespace {
206 struct TestLastModifiedPass
207     : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
208   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
209 
210   TestLastModifiedPass() = default;
211   TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) {
212     interprocedural = other.interprocedural;
213     assumeFuncWrites = other.assumeFuncWrites;
214   }
215 
216   StringRef getArgument() const override { return "test-last-modified"; }
217 
218   Option<bool> interprocedural{
219       *this, "interprocedural", llvm::cl::init(true),
220       llvm::cl::desc("perform interprocedural analysis")};
221   Option<bool> assumeFuncWrites{
222       *this, "assume-func-writes", llvm::cl::init(false),
223       llvm::cl::desc(
224           "assume external functions have write effect on all arguments")};
225 
226   void runOnOperation() override {
227     Operation *op = getOperation();
228 
229     DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
230     solver.load<DeadCodeAnalysis>();
231     solver.load<SparseConstantPropagation>();
232     solver.load<LastModifiedAnalysis>(assumeFuncWrites);
233     solver.load<UnderlyingValueAnalysis>();
234     if (failed(solver.initializeAndRun(op)))
235       return signalPassFailure();
236 
237     raw_ostream &os = llvm::errs();
238 
239     // Note that if the underlying value could not be computed or is unknown, we
240     // conservatively treat the result also unknown.
241     op->walk([&](Operation *op) {
242       auto tag = op->getAttrOfType<StringAttr>("tag");
243       if (!tag)
244         return;
245       os << "test_tag: " << tag.getValue() << ":\n";
246       const LastModification *lastMods =
247           solver.lookupState<LastModification>(solver.getProgramPointAfter(op));
248       assert(lastMods && "expected a dense lattice");
249       for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
250         os << " operand #" << index << "\n";
251         std::optional<Value> underlyingValue =
252             UnderlyingValueAnalysis::getMostUnderlyingValue(
253                 operand, [&](Value value) {
254                   return solver.lookupState<UnderlyingValueLattice>(value);
255                 });
256         if (!underlyingValue) {
257           os << " - <unknown>\n";
258           continue;
259         }
260         Value value = *underlyingValue;
261         assert(value && "expected an underlying value");
262         if (const AdjacentAccess *lastMod =
263                 lastMods->getAdjacentAccess(value)) {
264           if (!lastMod->isKnown()) {
265             os << " - <unknown>\n";
266           } else {
267             for (Operation *lastModifier : lastMod->get()) {
268               if (auto tagName =
269                       lastModifier->getAttrOfType<StringAttr>("tag_name")) {
270                 os << "  - " << tagName.getValue() << "\n";
271               } else {
272                 os << "  - " << lastModifier->getName() << "\n";
273               }
274             }
275           }
276         } else {
277           os << "  - <unknown>\n";
278         }
279       }
280     });
281   }
282 };
283 } // end anonymous namespace
284 
285 namespace mlir {
286 namespace test {
287 void registerTestLastModifiedPass() {
288   PassRegistration<TestLastModifiedPass>();
289 }
290 } // end namespace test
291 } // end namespace mlir
292