xref: /llvm-project/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
1 //===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Test pass for backward dense dataflow analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "TestDenseDataFlowAnalysis.h"
14 #include "TestDialect.h"
15 #include "TestOps.h"
16 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
18 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
19 #include "mlir/Analysis/DataFlowFramework.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/SymbolTable.h"
22 #include "mlir/Interfaces/CallInterfaces.h"
23 #include "mlir/Interfaces/ControlFlowInterfaces.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/TypeID.h"
27 #include "llvm/Support/raw_ostream.h"
28 
29 using namespace mlir;
30 using namespace mlir::dataflow;
31 using namespace mlir::dataflow::test;
32 
33 namespace {
34 
35 class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
36 public:
37   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess)
38 
39   using dataflow::AbstractDenseLattice::AbstractDenseLattice;
40 
41   ChangeResult meet(const AbstractDenseLattice &lattice) override {
42     return AccessLatticeBase::merge(static_cast<AccessLatticeBase>(
43         static_cast<const NextAccess &>(lattice)));
44   }
45 
46   void print(raw_ostream &os) const override {
47     return AccessLatticeBase::print(os);
48   }
49 };
50 
51 class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
52 public:
53   NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
54                      bool assumeFuncReads = false)
55       : DenseBackwardDataFlowAnalysis(solver, symbolTable),
56         assumeFuncReads(assumeFuncReads) {}
57 
58   LogicalResult visitOperation(Operation *op, const NextAccess &after,
59                                NextAccess *before) override;
60 
61   void visitCallControlFlowTransfer(CallOpInterface call,
62                                     CallControlFlowAction action,
63                                     const NextAccess &after,
64                                     NextAccess *before) override;
65 
66   void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
67                                             RegionBranchPoint regionFrom,
68                                             RegionBranchPoint regionTo,
69                                             const NextAccess &after,
70                                             NextAccess *before) override;
71 
72   // TODO: this isn't ideal for the analysis. When there is no next access, it
73   // means "we don't know what the next access is" rather than "there is no next
74   // access". But it's unclear how to differentiate the two cases...
75   void setToExitState(NextAccess *lattice) override {
76     propagateIfChanged(lattice, lattice->setKnownToUnknown());
77   }
78 
79   const bool assumeFuncReads;
80 };
81 } // namespace
82 
83 LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
84                                                  const NextAccess &after,
85                                                  NextAccess *before) {
86   auto memory = dyn_cast<MemoryEffectOpInterface>(op);
87   // If we can't reason about the memory effects, conservatively assume we can't
88   // say anything about the next access.
89   if (!memory) {
90     setToExitState(before);
91     return success();
92   }
93 
94   SmallVector<MemoryEffects::EffectInstance> effects;
95   memory.getEffects(effects);
96 
97   // First, check if all underlying values are already known. Otherwise, avoid
98   // propagating and stay in the "undefined" state to avoid incorrectly
99   // propagating values that may be overwritten later on as that could be
100   // problematic for convergence based on monotonicity of lattice updates.
101   SmallVector<Value> underlyingValues;
102   underlyingValues.reserve(effects.size());
103   for (const MemoryEffects::EffectInstance &effect : effects) {
104     Value value = effect.getValue();
105 
106     // Effects with unspecified value are treated conservatively and we cannot
107     // assume anything about the next access.
108     if (!value) {
109       setToExitState(before);
110       return success();
111     }
112 
113     // If cannot find the most underlying value, we cannot assume anything about
114     // the next accesses.
115     std::optional<Value> underlyingValue =
116         UnderlyingValueAnalysis::getMostUnderlyingValue(
117             value, [&](Value value) {
118               return getOrCreateFor<UnderlyingValueLattice>(
119                   getProgramPointBefore(op), value);
120             });
121 
122     // If the underlying value is not known yet, don't propagate.
123     if (!underlyingValue)
124       return success();
125 
126     underlyingValues.push_back(*underlyingValue);
127   }
128 
129   // Update the state if all underlying values are known.
130   ChangeResult result = before->meet(after);
131   for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) {
132     // If the underlying value is known to be unknown, set to fixpoint.
133     if (!value) {
134       setToExitState(before);
135       return success();
136     }
137 
138     result |= before->set(value, op);
139   }
140   propagateIfChanged(before, result);
141   return success();
142 }
143 
144 void NextAccessAnalysis::visitCallControlFlowTransfer(
145     CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
146     NextAccess *before) {
147   if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) {
148     SmallVector<Value> underlyingValues;
149     underlyingValues.reserve(call->getNumOperands());
150     for (Value operand : call.getArgOperands()) {
151       std::optional<Value> underlyingValue =
152           UnderlyingValueAnalysis::getMostUnderlyingValue(
153               operand, [&](Value value) {
154                 return getOrCreateFor<UnderlyingValueLattice>(
155                     getProgramPointBefore(call.getOperation()), value);
156               });
157       if (!underlyingValue)
158         return;
159       underlyingValues.push_back(*underlyingValue);
160     }
161 
162     ChangeResult result = before->meet(after);
163     for (Value operand : underlyingValues) {
164       result |= before->set(operand, call);
165     }
166     return propagateIfChanged(before, result);
167   }
168   auto testCallAndStore =
169       dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
170   if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
171                             testCallAndStore.getStoreBeforeCall()) ||
172                            (action == CallControlFlowAction::ExitCallee &&
173                             !testCallAndStore.getStoreBeforeCall()))) {
174     (void)visitOperation(call, after, before);
175   } else {
176     AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
177         call, action, after, before);
178   }
179 }
180 
181 void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
182     RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
183     RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
184   auto testStoreWithARegion =
185       dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
186 
187   if (testStoreWithARegion &&
188       ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
189        (regionFrom.isParent() &&
190         testStoreWithARegion.getStoreBeforeRegion()))) {
191     (void)visitOperation(branch, static_cast<const NextAccess &>(after),
192                          static_cast<NextAccess *>(before));
193   } else {
194     propagateIfChanged(before, before->meet(after));
195   }
196 }
197 
198 namespace {
199 struct TestNextAccessPass
200     : public PassWrapper<TestNextAccessPass, OperationPass<>> {
201   TestNextAccessPass() = default;
202   TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) {
203     interprocedural = other.interprocedural;
204     assumeFuncReads = other.assumeFuncReads;
205   }
206 
207   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass)
208 
209   StringRef getArgument() const override { return "test-next-access"; }
210 
211   Option<bool> interprocedural{
212       *this, "interprocedural", llvm::cl::init(true),
213       llvm::cl::desc("perform interprocedural analysis")};
214   Option<bool> assumeFuncReads{
215       *this, "assume-func-reads", llvm::cl::init(false),
216       llvm::cl::desc(
217           "assume external functions have read effect on all arguments")};
218 
219   static constexpr llvm::StringLiteral kTagAttrName = "name";
220   static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access";
221   static constexpr llvm::StringLiteral kAtEntryPointAttrName =
222       "next_at_entry_point";
223 
224   static Attribute makeNextAccessAttribute(Operation *op,
225                                            const DataFlowSolver &solver,
226                                            const NextAccess *nextAccess) {
227     if (!nextAccess)
228       return StringAttr::get(op->getContext(), "not computed");
229 
230     // Note that if the underlying value could not be computed or is unknown, we
231     // conservatively treat the result also unknown.
232     SmallVector<Attribute> attrs;
233     for (Value operand : op->getOperands()) {
234       std::optional<Value> underlyingValue =
235           UnderlyingValueAnalysis::getMostUnderlyingValue(
236               operand, [&](Value value) {
237                 return solver.lookupState<UnderlyingValueLattice>(value);
238               });
239       if (!underlyingValue) {
240         attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
241         continue;
242       }
243       Value value = *underlyingValue;
244       const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value);
245       if (!nextAcc || !nextAcc->isKnown()) {
246         attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
247         continue;
248       }
249 
250       SmallVector<Attribute> innerAttrs;
251       innerAttrs.reserve(nextAcc->get().size());
252       for (Operation *nextAccOp : nextAcc->get()) {
253         if (auto nextAccTag =
254                 nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) {
255           innerAttrs.push_back(nextAccTag);
256           continue;
257         }
258         std::string repr;
259         llvm::raw_string_ostream os(repr);
260         nextAccOp->print(os);
261         innerAttrs.push_back(StringAttr::get(op->getContext(), os.str()));
262       }
263       attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs));
264     }
265     return ArrayAttr::get(op->getContext(), attrs);
266   }
267 
268   void runOnOperation() override {
269     Operation *op = getOperation();
270     SymbolTableCollection symbolTable;
271 
272     auto config = DataFlowConfig().setInterprocedural(interprocedural);
273     DataFlowSolver solver(config);
274     solver.load<DeadCodeAnalysis>();
275     solver.load<NextAccessAnalysis>(symbolTable, assumeFuncReads);
276     solver.load<SparseConstantPropagation>();
277     solver.load<UnderlyingValueAnalysis>();
278     if (failed(solver.initializeAndRun(op))) {
279       emitError(op->getLoc(), "dataflow solver failed");
280       return signalPassFailure();
281     }
282     op->walk([&](Operation *op) {
283       auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
284       if (!tag)
285         return;
286 
287       const NextAccess *nextAccess =
288           solver.lookupState<NextAccess>(solver.getProgramPointAfter(op));
289       op->setAttr(kNextAccessAttrName,
290                   makeNextAccessAttribute(op, solver, nextAccess));
291 
292       auto iface = dyn_cast<RegionBranchOpInterface>(op);
293       if (!iface)
294         return;
295 
296       SmallVector<Attribute> entryPointNextAccess;
297       SmallVector<RegionSuccessor> regionSuccessors;
298       iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
299       for (const RegionSuccessor &successor : regionSuccessors) {
300         if (!successor.getSuccessor() || successor.getSuccessor()->empty())
301           continue;
302         Block &successorBlock = successor.getSuccessor()->front();
303         ProgramPoint *successorPoint =
304             solver.getProgramPointBefore(&successorBlock);
305         entryPointNextAccess.push_back(makeNextAccessAttribute(
306             op, solver, solver.lookupState<NextAccess>(successorPoint)));
307       }
308       op->setAttr(kAtEntryPointAttrName,
309                   ArrayAttr::get(op->getContext(), entryPointNextAccess));
310     });
311   }
312 };
313 } // namespace
314 
315 namespace mlir::test {
316 void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); }
317 } // namespace mlir::test
318