xref: /llvm-project/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (revision 4b3f251bada55cfc20a2c72321fa0bbfd7a759d5)
//===- TestSparseBackwardDataFlowAnalysis.cpp - Test written-to analysis --===//
2b2b7efb9SAlex Zinenko //
3b2b7efb9SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b2b7efb9SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5b2b7efb9SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b2b7efb9SAlex Zinenko //
7b2b7efb9SAlex Zinenko //===----------------------------------------------------------------------===//
8b2b7efb9SAlex Zinenko 
9b2b7efb9SAlex Zinenko #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10b2b7efb9SAlex Zinenko #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11b2b7efb9SAlex Zinenko #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12b2b7efb9SAlex Zinenko #include "mlir/Dialect/MemRef/IR/MemRef.h"
13b2b7efb9SAlex Zinenko #include "mlir/Interfaces/SideEffectInterfaces.h"
14b2b7efb9SAlex Zinenko #include "mlir/Pass/Pass.h"
15b2b7efb9SAlex Zinenko 
16b2b7efb9SAlex Zinenko using namespace mlir;
17b2b7efb9SAlex Zinenko using namespace mlir::dataflow;
18b2b7efb9SAlex Zinenko 
19b2b7efb9SAlex Zinenko namespace {
20b2b7efb9SAlex Zinenko 
21d1cff36eSAndi Drebes /// Lattice value storing the a set of memory resources that something
22d1cff36eSAndi Drebes /// is written to.
23d1cff36eSAndi Drebes struct WrittenToLatticeValue {
24d1cff36eSAndi Drebes   bool operator==(const WrittenToLatticeValue &other) {
25d1cff36eSAndi Drebes     return this->writes == other.writes;
26b2b7efb9SAlex Zinenko   }
27d1cff36eSAndi Drebes 
28d1cff36eSAndi Drebes   static WrittenToLatticeValue meet(const WrittenToLatticeValue &lhs,
29d1cff36eSAndi Drebes                                     const WrittenToLatticeValue &rhs) {
30d1cff36eSAndi Drebes     WrittenToLatticeValue res = lhs;
31d1cff36eSAndi Drebes     (void)res.addWrites(rhs.writes);
32d1cff36eSAndi Drebes 
33d1cff36eSAndi Drebes     return res;
34d1cff36eSAndi Drebes   }
35d1cff36eSAndi Drebes 
36d1cff36eSAndi Drebes   static WrittenToLatticeValue join(const WrittenToLatticeValue &lhs,
37d1cff36eSAndi Drebes                                     const WrittenToLatticeValue &rhs) {
38d1cff36eSAndi Drebes     // Should not be triggered by this test, but required by `Lattice<T>`
39d1cff36eSAndi Drebes     llvm_unreachable("Join should not be triggered by this test");
40d1cff36eSAndi Drebes   }
41d1cff36eSAndi Drebes 
42b2b7efb9SAlex Zinenko   ChangeResult addWrites(const SetVector<StringAttr> &writes) {
43b2b7efb9SAlex Zinenko     int sizeBefore = this->writes.size();
44b2b7efb9SAlex Zinenko     this->writes.insert(writes.begin(), writes.end());
45b2b7efb9SAlex Zinenko     int sizeAfter = this->writes.size();
46b2b7efb9SAlex Zinenko     return sizeBefore == sizeAfter ? ChangeResult::NoChange
47b2b7efb9SAlex Zinenko                                    : ChangeResult::Change;
48b2b7efb9SAlex Zinenko   }
49d1cff36eSAndi Drebes 
50d1cff36eSAndi Drebes   void print(raw_ostream &os) const {
51d1cff36eSAndi Drebes     os << "[";
52d1cff36eSAndi Drebes     llvm::interleave(
53d1cff36eSAndi Drebes         writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
54d1cff36eSAndi Drebes     os << "]";
55b2b7efb9SAlex Zinenko   }
56b2b7efb9SAlex Zinenko 
57d1cff36eSAndi Drebes   void clear() { writes.clear(); }
58d1cff36eSAndi Drebes 
59b2b7efb9SAlex Zinenko   SetVector<StringAttr> writes;
60b2b7efb9SAlex Zinenko };
61b2b7efb9SAlex Zinenko 
62d1cff36eSAndi Drebes /// This lattice represents, for a given value, the set of memory resources that
63d1cff36eSAndi Drebes /// this value, or anything derived from this value, is potentially written to.
64d1cff36eSAndi Drebes struct WrittenTo : public Lattice<WrittenToLatticeValue> {
65d1cff36eSAndi Drebes   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
66d1cff36eSAndi Drebes   using Lattice::Lattice;
67d1cff36eSAndi Drebes };
68d1cff36eSAndi Drebes 
69b2b7efb9SAlex Zinenko /// An analysis that, by going backwards along the dataflow graph, annotates
70b2b7efb9SAlex Zinenko /// each value with all the memory resources it (or anything derived from it)
71b2b7efb9SAlex Zinenko /// is eventually written to.
72b2b7efb9SAlex Zinenko class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
73b2b7efb9SAlex Zinenko public:
7432a4e3fcSOleksandr "Alex" Zinenko   WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
7532a4e3fcSOleksandr "Alex" Zinenko                     bool assumeFuncWrites)
7632a4e3fcSOleksandr "Alex" Zinenko       : SparseBackwardDataFlowAnalysis(solver, symbolTable),
7732a4e3fcSOleksandr "Alex" Zinenko         assumeFuncWrites(assumeFuncWrites) {}
78b2b7efb9SAlex Zinenko 
7915e915a4SIvan Butygin   LogicalResult visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
80b2b7efb9SAlex Zinenko                                ArrayRef<const WrittenTo *> results) override;
81b2b7efb9SAlex Zinenko 
82b2b7efb9SAlex Zinenko   void visitBranchOperand(OpOperand &operand) override;
83b2b7efb9SAlex Zinenko 
84232f8eadSSrishti Srivastava   void visitCallOperand(OpOperand &operand) override;
85232f8eadSSrishti Srivastava 
8632a4e3fcSOleksandr "Alex" Zinenko   void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
8732a4e3fcSOleksandr "Alex" Zinenko                          ArrayRef<const WrittenTo *> results) override;
8832a4e3fcSOleksandr "Alex" Zinenko 
89d1cff36eSAndi Drebes   void setToExitState(WrittenTo *lattice) override {
90d1cff36eSAndi Drebes     lattice->getValue().clear();
91d1cff36eSAndi Drebes   }
9232a4e3fcSOleksandr "Alex" Zinenko 
9332a4e3fcSOleksandr "Alex" Zinenko private:
9432a4e3fcSOleksandr "Alex" Zinenko   bool assumeFuncWrites;
95b2b7efb9SAlex Zinenko };
96b2b7efb9SAlex Zinenko 
9715e915a4SIvan Butygin LogicalResult
9815e915a4SIvan Butygin WrittenToAnalysis::visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
99b2b7efb9SAlex Zinenko                                   ArrayRef<const WrittenTo *> results) {
100b2b7efb9SAlex Zinenko   if (auto store = dyn_cast<memref::StoreOp>(op)) {
101b2b7efb9SAlex Zinenko     SetVector<StringAttr> newWrites;
102b2b7efb9SAlex Zinenko     newWrites.insert(op->getAttrOfType<StringAttr>("tag_name"));
103d1cff36eSAndi Drebes     propagateIfChanged(operands[0],
104d1cff36eSAndi Drebes                        operands[0]->getValue().addWrites(newWrites));
10515e915a4SIvan Butygin     return success();
106b2b7efb9SAlex Zinenko   } // By default, every result of an op depends on every operand.
107b2b7efb9SAlex Zinenko   for (const WrittenTo *r : results) {
108b2b7efb9SAlex Zinenko     for (WrittenTo *operand : operands) {
109b2b7efb9SAlex Zinenko       meet(operand, *r);
110b2b7efb9SAlex Zinenko     }
111*4b3f251bSdonald chen     addDependency(const_cast<WrittenTo *>(r), getProgramPointAfter(op));
112b2b7efb9SAlex Zinenko   }
11315e915a4SIvan Butygin   return success();
114b2b7efb9SAlex Zinenko }
115b2b7efb9SAlex Zinenko 
116b2b7efb9SAlex Zinenko void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
117b2b7efb9SAlex Zinenko   // Mark branch operands as "brancharg%d", with %d the operand number.
118b2b7efb9SAlex Zinenko   WrittenTo *lattice = getLatticeElement(operand.get());
119b2b7efb9SAlex Zinenko   SetVector<StringAttr> newWrites;
120b2b7efb9SAlex Zinenko   newWrites.insert(
121b2b7efb9SAlex Zinenko       StringAttr::get(operand.getOwner()->getContext(),
122b2b7efb9SAlex Zinenko                       "brancharg" + Twine(operand.getOperandNumber())));
123d1cff36eSAndi Drebes   propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
124b2b7efb9SAlex Zinenko }
125b2b7efb9SAlex Zinenko 
126232f8eadSSrishti Srivastava void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
127232f8eadSSrishti Srivastava   // Mark call operands as "callarg%d", with %d the operand number.
128232f8eadSSrishti Srivastava   WrittenTo *lattice = getLatticeElement(operand.get());
129232f8eadSSrishti Srivastava   SetVector<StringAttr> newWrites;
130232f8eadSSrishti Srivastava   newWrites.insert(
131232f8eadSSrishti Srivastava       StringAttr::get(operand.getOwner()->getContext(),
132232f8eadSSrishti Srivastava                       "callarg" + Twine(operand.getOperandNumber())));
133d1cff36eSAndi Drebes   propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
134232f8eadSSrishti Srivastava }
135232f8eadSSrishti Srivastava 
13632a4e3fcSOleksandr "Alex" Zinenko void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
13732a4e3fcSOleksandr "Alex" Zinenko                                           ArrayRef<WrittenTo *> operands,
13832a4e3fcSOleksandr "Alex" Zinenko                                           ArrayRef<const WrittenTo *> results) {
13932a4e3fcSOleksandr "Alex" Zinenko   if (!assumeFuncWrites) {
14032a4e3fcSOleksandr "Alex" Zinenko     return SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands,
14132a4e3fcSOleksandr "Alex" Zinenko                                                              results);
14232a4e3fcSOleksandr "Alex" Zinenko   }
14332a4e3fcSOleksandr "Alex" Zinenko 
14432a4e3fcSOleksandr "Alex" Zinenko   for (WrittenTo *lattice : operands) {
14532a4e3fcSOleksandr "Alex" Zinenko     SetVector<StringAttr> newWrites;
14632a4e3fcSOleksandr "Alex" Zinenko     StringAttr name = call->getAttrOfType<StringAttr>("tag_name");
14732a4e3fcSOleksandr "Alex" Zinenko     if (!name) {
14832a4e3fcSOleksandr "Alex" Zinenko       name = StringAttr::get(call->getContext(),
14932a4e3fcSOleksandr "Alex" Zinenko                              call.getOperation()->getName().getStringRef());
15032a4e3fcSOleksandr "Alex" Zinenko     }
15132a4e3fcSOleksandr "Alex" Zinenko     newWrites.insert(name);
152d1cff36eSAndi Drebes     propagateIfChanged(lattice, lattice->getValue().addWrites(newWrites));
15332a4e3fcSOleksandr "Alex" Zinenko   }
15432a4e3fcSOleksandr "Alex" Zinenko }
15532a4e3fcSOleksandr "Alex" Zinenko 
156b2b7efb9SAlex Zinenko } // end anonymous namespace
157b2b7efb9SAlex Zinenko 
158b2b7efb9SAlex Zinenko namespace {
159b2b7efb9SAlex Zinenko struct TestWrittenToPass
160b2b7efb9SAlex Zinenko     : public PassWrapper<TestWrittenToPass, OperationPass<>> {
161b2b7efb9SAlex Zinenko   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
162b2b7efb9SAlex Zinenko 
16332a4e3fcSOleksandr "Alex" Zinenko   TestWrittenToPass() = default;
16432a4e3fcSOleksandr "Alex" Zinenko   TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
16532a4e3fcSOleksandr "Alex" Zinenko     interprocedural = other.interprocedural;
16632a4e3fcSOleksandr "Alex" Zinenko     assumeFuncWrites = other.assumeFuncWrites;
16732a4e3fcSOleksandr "Alex" Zinenko   }
16832a4e3fcSOleksandr "Alex" Zinenko 
169b2b7efb9SAlex Zinenko   StringRef getArgument() const override { return "test-written-to"; }
170b2b7efb9SAlex Zinenko 
17132a4e3fcSOleksandr "Alex" Zinenko   Option<bool> interprocedural{
17232a4e3fcSOleksandr "Alex" Zinenko       *this, "interprocedural", llvm::cl::init(true),
17332a4e3fcSOleksandr "Alex" Zinenko       llvm::cl::desc("perform interprocedural analysis")};
17432a4e3fcSOleksandr "Alex" Zinenko   Option<bool> assumeFuncWrites{
17532a4e3fcSOleksandr "Alex" Zinenko       *this, "assume-func-writes", llvm::cl::init(false),
17632a4e3fcSOleksandr "Alex" Zinenko       llvm::cl::desc(
17732a4e3fcSOleksandr "Alex" Zinenko           "assume external functions have write effect on all arguments")};
17832a4e3fcSOleksandr "Alex" Zinenko 
179b2b7efb9SAlex Zinenko   void runOnOperation() override {
180b2b7efb9SAlex Zinenko     Operation *op = getOperation();
181b2b7efb9SAlex Zinenko 
182b2b7efb9SAlex Zinenko     SymbolTableCollection symbolTable;
183b2b7efb9SAlex Zinenko 
18432a4e3fcSOleksandr "Alex" Zinenko     DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
185b2b7efb9SAlex Zinenko     solver.load<DeadCodeAnalysis>();
186b2b7efb9SAlex Zinenko     solver.load<SparseConstantPropagation>();
18732a4e3fcSOleksandr "Alex" Zinenko     solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites);
188b2b7efb9SAlex Zinenko     if (failed(solver.initializeAndRun(op)))
189b2b7efb9SAlex Zinenko       return signalPassFailure();
190b2b7efb9SAlex Zinenko 
191b2b7efb9SAlex Zinenko     raw_ostream &os = llvm::outs();
192b2b7efb9SAlex Zinenko     op->walk([&](Operation *op) {
193b2b7efb9SAlex Zinenko       auto tag = op->getAttrOfType<StringAttr>("tag");
194b2b7efb9SAlex Zinenko       if (!tag)
195b2b7efb9SAlex Zinenko         return;
196b2b7efb9SAlex Zinenko       os << "test_tag: " << tag.getValue() << ":\n";
197b2b7efb9SAlex Zinenko       for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
198b2b7efb9SAlex Zinenko         const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
199b2b7efb9SAlex Zinenko         assert(writtenTo && "expected a sparse lattice");
200b2b7efb9SAlex Zinenko         os << " operand #" << index << ": ";
201b2b7efb9SAlex Zinenko         writtenTo->print(os);
202b2b7efb9SAlex Zinenko         os << "\n";
203b2b7efb9SAlex Zinenko       }
204b2b7efb9SAlex Zinenko       for (auto [index, operand] : llvm::enumerate(op->getResults())) {
205b2b7efb9SAlex Zinenko         const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
206b2b7efb9SAlex Zinenko         assert(writtenTo && "expected a sparse lattice");
207b2b7efb9SAlex Zinenko         os << " result #" << index << ": ";
208b2b7efb9SAlex Zinenko         writtenTo->print(os);
209b2b7efb9SAlex Zinenko         os << "\n";
210b2b7efb9SAlex Zinenko       }
211b2b7efb9SAlex Zinenko     });
212b2b7efb9SAlex Zinenko   }
213b2b7efb9SAlex Zinenko };
214b2b7efb9SAlex Zinenko } // end anonymous namespace
215b2b7efb9SAlex Zinenko 
216b2b7efb9SAlex Zinenko namespace mlir {
217b2b7efb9SAlex Zinenko namespace test {
218b2b7efb9SAlex Zinenko void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
219b2b7efb9SAlex Zinenko } // end namespace test
220b2b7efb9SAlex Zinenko } // end namespace mlir
221